00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #include <shogun/multiclass/tree/BalancedConditionalProbabilityTree.h> 00012 00013 using namespace shogun; 00014 00015 CBalancedConditionalProbabilityTree::CBalancedConditionalProbabilityTree() 00016 :m_alpha(0.4) 00017 { 00018 SG_ADD(&m_alpha, "m_alpha", "Trade-off parameter of tree balance", MS_NOT_AVAILABLE); 00019 } 00020 00021 void CBalancedConditionalProbabilityTree::set_alpha(float64_t alpha) 00022 { 00023 if (alpha < 0 || alpha > 1) 00024 SG_ERROR("expect 0 <= alpha <= 1, but got %g\n", alpha); 00025 m_alpha = alpha; 00026 } 00027 00028 bool CBalancedConditionalProbabilityTree::which_subtree(node_t *node, SGVector<float32_t> ex) 00029 { 00030 float64_t pred = predict_node(ex, node); 00031 float64_t depth_left = tree_depth(node->left()); 00032 float64_t depth_right = tree_depth(node->right()); 00033 00034 float64_t cnt_left = CMath::pow(2.0, depth_left); 00035 float64_t cnt_right = CMath::pow(2.0, depth_right); 00036 00037 float64_t obj_val = (1-m_alpha) * 2 * (pred-0.5) + m_alpha * CMath::log2(cnt_left/cnt_right); 00038 00039 if (obj_val > 0) 00040 return false; // go right 00041 return true; // go left 00042 } 00043 00044 int32_t CBalancedConditionalProbabilityTree::tree_depth(node_t *node) 00045 { 00046 int32_t depth = 0; 00047 while (node != NULL) 00048 { 00049 depth++; 00050 node = node->left(); 00051 } 00052 00053 return depth; 00054 }