BalancedConditionalProbabilityTree.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #include <shogun/multiclass/tree/BalancedConditionalProbabilityTree.h>
00012 
00013 using namespace shogun;
00014 
00015 CBalancedConditionalProbabilityTree::CBalancedConditionalProbabilityTree()
00016     :m_alpha(0.4)
00017 {
00018     SG_ADD(&m_alpha, "m_alpha", "Trade-off parameter of tree balance", MS_NOT_AVAILABLE);
00019 }
00020 
00021 void CBalancedConditionalProbabilityTree::set_alpha(float64_t alpha)
00022 {
00023     if (alpha < 0 || alpha > 1)
00024         SG_ERROR("expect 0 <= alpha <= 1, but got %g\n", alpha);
00025     m_alpha = alpha;
00026 }
00027 
00028 bool CBalancedConditionalProbabilityTree::which_subtree(node_t *node, SGVector<float32_t> ex)
00029 {
00030     float64_t pred = predict_node(ex, node);
00031     float64_t depth_left = tree_depth(node->left());
00032     float64_t depth_right = tree_depth(node->right());
00033 
00034     float64_t cnt_left = CMath::pow(2.0, depth_left);
00035     float64_t cnt_right = CMath::pow(2.0, depth_right);
00036 
00037     float64_t obj_val = (1-m_alpha) * 2 * (pred-0.5) + m_alpha * CMath::log2(cnt_left/cnt_right);
00038 
00039     if (obj_val > 0)
00040         return false; // go right
00041     return true; // go left
00042 }
00043 
00044 int32_t CBalancedConditionalProbabilityTree::tree_depth(node_t *node)
00045 {
00046     int32_t depth = 0;
00047     while (node != NULL)
00048     {
00049         depth++;
00050         node = node->left();
00051     }
00052 
00053     return depth;
00054 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation