RelaxedTree.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #include <limits>
00012 #include <queue>
00013 #include <algorithm>
00014 #include <functional>
00015 
00016 #include <shogun/labels/BinaryLabels.h>
00017 #include <shogun/multiclass/tree/RelaxedTreeUtil.h>
00018 #include <shogun/multiclass/tree/RelaxedTree.h>
00019 #include <shogun/kernel/GaussianKernel.h>
00020 
00021 
00022 using namespace shogun;
00023 
/** Default constructor.
 *
 * Defaults: at most 3 alternating-optimization iterations, A=0.5 (used as
 * the A/C shift on SVM outputs), B=5 (bound of the balance constraint on
 * the class coloring), SVM regularization C=1 and solver tolerance
 * epsilon=0.001.  Kernel, features and the machine used for
 * confusion-matrix estimation start out unset (NULL).
 */
CRelaxedTree::CRelaxedTree()
    :m_max_num_iter(3), m_A(0.5), m_B(5), m_svm_C(1), m_svm_epsilon(0.001), 
    m_kernel(NULL), m_feats(NULL), m_machine_for_confusion_matrix(NULL), m_num_classes(0)
{
    // Register parameters for serialization / model selection.
    // NOTE: registration order is part of the serialized format; keep it.
    SG_ADD(&m_max_num_iter, "m_max_num_iter", "max number of iterations in alternating optimization", MS_NOT_AVAILABLE);
    SG_ADD(&m_svm_C, "m_svm_C", "C for svm", MS_AVAILABLE);
    SG_ADD(&m_A, "m_A", "parameter A", MS_AVAILABLE);
    SG_ADD(&m_B, "m_B", "parameter B", MS_AVAILABLE);
    SG_ADD(&m_svm_epsilon, "m_svm_epsilon", "epsilon for svm", MS_AVAILABLE);
}
00034 
/** Destructor: drop the references held on the kernel, the feature set and
 *  the auxiliary machine used for confusion-matrix estimation
 *  (SG_UNREF is NULL-safe).
 */
CRelaxedTree::~CRelaxedTree()
{
    SG_UNREF(m_kernel);
    SG_UNREF(m_feats);
    SG_UNREF(m_machine_for_confusion_matrix);
}
00041 
00042 CMulticlassLabels* CRelaxedTree::apply_multiclass(CFeatures* data)
00043 {
00044     if (data != NULL)
00045     {
00046         CDenseFeatures<float64_t> *feats = dynamic_cast<CDenseFeatures<float64_t>*>(data);
00047         REQUIRE(feats != NULL, ("Require non-NULL dense features of float64_t\n"));
00048         set_features(feats);
00049     }
00050 
00051     // init kernels for all sub-machines
00052     for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00053     {
00054         CSVM *machine = (CSVM*)m_machines->get_element(i);
00055         CKernel *kernel = machine->get_kernel();
00056         CFeatures* lhs = kernel->get_lhs();
00057         kernel->init(lhs, m_feats);
00058         SG_UNREF(machine);
00059         SG_UNREF(kernel);
00060         SG_UNREF(lhs);
00061     }
00062 
00063     CMulticlassLabels *lab = new CMulticlassLabels(m_feats->get_num_vectors());
00064     SG_REF(lab);
00065     for (int32_t i=0; i < lab->get_num_labels(); ++i)
00066     {
00067         lab->set_int_label(i, int32_t(apply_one(i)));
00068     }
00069 
00070     return lab;
00071 }
00072 
00073 float64_t CRelaxedTree::apply_one(int32_t idx)
00074 {
00075     node_t *node = m_root;
00076     int32_t klass = -1;
00077     while (node != NULL)
00078     {
00079         CSVM *svm = (CSVM *)m_machines->get_element(node->machine());
00080         float64_t result = svm->apply_one(idx);
00081 
00082         if (result < 0)
00083         {
00084             // go left
00085             if (node->left()) // has left subtree
00086             {
00087                 node = node->left();
00088             }
00089             else // stop here
00090             {
00091                 for (int32_t i=0; i < node->data.mu.vlen; ++i)
00092                 {
00093                     if (node->data.mu[i] <= 0 && node->data.mu[i] > -2)
00094                     {
00095                         klass = i;
00096                         break;
00097                     }
00098                 }
00099                 node = NULL;
00100             }
00101         }
00102         else
00103         {
00104             // go right
00105             if (node->right())
00106             {
00107                 node = node->right();
00108             }
00109             else
00110             {
00111                 for (int32_t i=0; i <node->data.mu.vlen; ++i)
00112                 {
00113                     if (node->data.mu[i] >= 0)
00114                     {
00115                         klass = i;
00116                         break;
00117                     }
00118                 }
00119                 node = NULL;
00120             }
00121         }
00122 
00123         SG_UNREF(svm);
00124     }
00125 
00126     return klass;
00127 }
00128 
/** Train the relaxed label tree top-down, breadth-first.
 *
 * Requires m_machine_for_confusion_matrix and m_kernel to be set.  A
 * confusion matrix of the auxiliary machine is estimated first; the root is
 * trained on all classes, then each node's active class set is split into a
 * left part (mu <= 0, excluding pruned mu == -2) and a right part
 * (mu >= 0).  Classes colored 0 go to BOTH children (the "relaxed" part).
 * A child node is only trained while at least two classes remain.
 *
 * @param data optional dense float64_t training features (replaces m_feats)
 * @return true on success (failures are raised via SG_ERROR)
 */
bool CRelaxedTree::train_machine(CFeatures* data)
{
    if (m_machine_for_confusion_matrix == NULL)
        SG_ERROR("Call set_machine_for_confusion_matrix before training\n");
    if (m_kernel == NULL)
        SG_ERROR("assign a valid kernel before training\n");

    if (data)
    {
        CDenseFeatures<float64_t> *feats = dynamic_cast<CDenseFeatures<float64_t>*>(data);
        if (feats == NULL)
            SG_ERROR("Require non-NULL dense features of float64_t\n");
        set_features(feats);
    }

    CMulticlassLabels *lab = dynamic_cast<CMulticlassLabels *>(m_labels);

    // confusion matrix of the auxiliary machine guides node initialization
    RelaxedTreeUtil util;
    SGMatrix<float64_t> conf_mat = util.estimate_confusion_matrix(m_machine_for_confusion_matrix,
            m_feats, lab, m_num_classes);

    // train root on the full class set {0, ..., m_num_classes-1}
    SGVector<int32_t> classes(m_num_classes);

    for (int32_t i=0; i < m_num_classes; ++i)
        classes[i] = i;

    SG_UNREF(m_root);
    m_root = train_node(conf_mat, classes);

    // breadth-first expansion of the tree
    std::queue<node_t *> node_q;
    node_q.push(m_root);

    while (node_q.size() != 0)
    {
        node_t *node = node_q.front();

        // left node
        SGVector <int32_t> left_classes(m_num_classes);
        int32_t k=0;
        for (int32_t i=0; i < m_num_classes; ++i)
        {
            // active classes are labeled as -1 or 0
            // -2 indicate classes that are already pruned away
            if (node->data.mu[i] <= 0 && node->data.mu[i] > -2)
                left_classes[k++] = i;
        }

        left_classes.vlen = k;

        // only split further while there is something to separate
        if (left_classes.vlen >= 2)
        {
            node_t *left_node = train_node(conf_mat, left_classes);
            node->left(left_node);
            node_q.push(left_node);
        }

        // right node
        SGVector <int32_t> right_classes(m_num_classes);
        k=0;
        for (int32_t i=0; i < m_num_classes; ++i)
        {
            // active classes are labeled as 0 or +1
            if (node->data.mu[i] >= 0)
                right_classes[k++] = i;
        }

        right_classes.vlen = k;

        if (right_classes.vlen >= 2)
        {
            node_t *right_node = train_node(conf_mat, right_classes);
            node->right(right_node);
            node_q.push(right_node);
        }

        node_q.pop();
    }

    //m_root->debug_print(RelaxedTreeNodeData::print_data);

    return true;
}
00212 
/** Train one tree node: try several initializations, keep the best SVM.
 *
 * For each candidate class-pair initialization from init_node() an SVM is
 * trained via alternating optimization; the one with the lowest
 * compute_score() wins.  The winning SVM is appended to m_machines and its
 * index stored in the node.  The node's mu is expanded to the full class
 * space, with -2 marking classes that are not in @p classes at all.
 *
 * @param conf_mat global confusion matrix (m_num_classes x m_num_classes)
 * @param classes  active classes of this node
 * @return newly created node (one reference held for the caller)
 */
CRelaxedTree::node_t *CRelaxedTree::train_node(const SGMatrix<float64_t> &conf_mat, SGVector<int32_t> classes)
{
    SGVector<int32_t> best_mu;
    CSVM *best_svm = NULL;
    float64_t best_score = std::numeric_limits<float64_t>::max();

    // candidate initializations: most-confused class pairs first
    std::vector<CRelaxedTree::entry_t> mu_init = init_node(conf_mat, classes);
    for (std::vector<CRelaxedTree::entry_t>::const_iterator it = mu_init.begin(); it != mu_init.end(); ++it)
    {
        CSVM *svm = new CLibSVM();
        SG_REF(svm);
        svm->set_store_model_features(true);
        
        SGVector<int32_t> mu = train_node_with_initialization(*it, classes, svm);
        float64_t score = compute_score(mu, svm);

        if (score < best_score) // lower score is better
        {
            best_score = score;
            best_mu = mu;
            SG_UNREF(best_svm); // drop previous best (NULL-safe)
            best_svm = svm;
        }
        else
        {
            SG_UNREF(svm);
        }
    }

    node_t *node = new node_t;
    SG_REF(node);

    // register winning machine; the node stores its index into m_machines
    m_machines->push_back(best_svm);
    node->machine(m_machines->get_num_elements()-1);

    // expand mu to full class space; -2 = class not handled by this node
    SGVector<int32_t> long_mu(m_num_classes);
    std::fill(&long_mu[0], &long_mu[m_num_classes], -2);
    for (int32_t i=0; i < best_mu.vlen; ++i)
    {
        if (best_mu[i] == 1)
            long_mu[classes[i]] = 1;
        else if (best_mu[i] == -1)
            long_mu[classes[i]] = -1;
        else if (best_mu[i] == 0)
            long_mu[classes[i]] = 0;
    }

    node->data.mu = long_mu;
    return node;
}
00263 
00264 float64_t CRelaxedTree::compute_score(SGVector<int32_t> mu, CSVM *svm)
00265 {
00266     float64_t num_pos=0, num_neg=0;
00267     for (int32_t i=0; i < mu.vlen; ++i)
00268     {
00269         if (mu[i] == 1)
00270             num_pos++;
00271         else if (mu[i] == -1)
00272             num_neg++;
00273     }
00274 
00275     int32_t totalSV = svm->get_support_vectors().vlen;
00276     float64_t score = num_neg/(num_neg+num_pos) * totalSV/num_pos + 
00277         num_pos/(num_neg+num_pos)*totalSV/num_neg;
00278     return score;
00279 }
00280 
/** Alternating optimization for one node from a single initialization.
 *
 * mu is seeded with +1/-1 on the two classes of @p mu_entry (all others 0).
 * Each iteration trains the binary SVM on the instances of the currently
 * active (mu != 0) classes, then recolors all classes via
 * color_label_space().  Stops after m_max_num_iter iterations or as soon as
 * mu no longer changes.
 *
 * @param mu_entry ((local class index pair), confusion score) seed
 * @param classes  active classes (indices into the full label range)
 * @param svm      SVM trained in place
 * @return final coloring mu over @p classes (+1 right, -1 left, 0 both)
 */
SGVector<int32_t> CRelaxedTree::train_node_with_initialization(const CRelaxedTree::entry_t &mu_entry, SGVector<int32_t> classes, CSVM *svm)
{
    SGVector<int32_t> mu(classes.vlen), prev_mu(classes.vlen);
    mu.zero();
    mu[mu_entry.first.first] = 1;
    mu[mu_entry.first.second] = -1;

    SGVector<int32_t> long_mu(m_num_classes);
    svm->set_C(m_svm_C, m_svm_C);
    svm->set_epsilon(m_svm_epsilon);

    for (int32_t iiter=0; iiter < m_max_num_iter; ++iiter)
    {
        // expand mu from the local class subset to the full label range
        long_mu.zero();
        for (int32_t i=0; i < classes.vlen; ++i)
        {
            if (mu[i] == 1)
                long_mu[classes[i]] = 1;
            else if (mu[i] == -1)
                long_mu[classes[i]] = -1;
        }

        SGVector<int32_t> subset(m_feats->get_num_vectors());
        SGVector<float64_t> binlab(m_feats->get_num_vectors());
        int32_t k=0;

        // binary training labels: only instances whose class is active
        // (mu != 0) participate; they are collected in `subset`
        CMulticlassLabels *labs = dynamic_cast<CMulticlassLabels *>(m_labels);
        for (int32_t i=0; i < binlab.vlen; ++i)
        {
            int32_t lab = labs->get_int_label(i);
            binlab[i] = long_mu[lab];
            if (long_mu[lab] != 0)
                subset[k++] = i;
        }

        subset.vlen = k;

        CBinaryLabels *binary_labels = new CBinaryLabels(binlab);
        SG_REF(binary_labels);
        binary_labels->add_subset(subset);
        m_feats->add_subset(subset);

        // train on a shallow copy so the stored kernel remains untouched
        CKernel *kernel = (CKernel *)m_kernel->shallow_copy();
        kernel->init(m_feats, m_feats);
        svm->set_kernel(kernel);
        svm->set_labels(binary_labels);
        svm->train();

        // undo the subset before the next iteration / recoloring
        binary_labels->remove_subset();
        m_feats->remove_subset();
        SG_UNREF(binary_labels);

        std::copy(&mu[0], &mu[mu.vlen], &prev_mu[0]);

        mu = color_label_space(svm, classes);

        // converged once the coloring is stable
        bool bbreak = true;
        for (int32_t i=0; i < mu.vlen; ++i)
        {
            if (mu[i] != prev_mu[i])
            {
                bbreak = false;
                break;
            }
        }

        if (bbreak)
            break;
    }

    return mu;
}
00353 
00354 struct EntryComparator
00355 {
00356     bool operator() (const CRelaxedTree::entry_t& e1, const CRelaxedTree::entry_t& e2)
00357     {
00358         return e1.second < e2.second;
00359     }
00360 };
00361 std::vector<CRelaxedTree::entry_t> CRelaxedTree::init_node(const SGMatrix<float64_t> &global_conf_mat, SGVector<int32_t> classes)
00362 {
00363     // local confusion matrix
00364     SGMatrix<float64_t> conf_mat(classes.vlen, classes.vlen);
00365     for (index_t i=0; i < conf_mat.num_rows; ++i)
00366     {
00367         for (index_t j=0; j < conf_mat.num_cols; ++j)
00368         {
00369             conf_mat(i, j) = global_conf_mat(classes[i], classes[j]);
00370         }
00371     }
00372 
00373     // make conf matrix symmetry
00374     for (index_t i=0; i < conf_mat.num_rows; ++i)
00375     {
00376         for (index_t j=0; j < conf_mat.num_cols; ++j)
00377         {
00378             conf_mat(i,j) += conf_mat(j,i);
00379         }
00380     }
00381 
00382     std::vector<CRelaxedTree::entry_t> entries;
00383     for (index_t i=0; i < classes.vlen; ++i)
00384     {
00385         for (index_t j=i+1; j < classes.vlen; ++j)
00386         {
00387             entries.push_back(std::make_pair(std::make_pair(i, j), conf_mat(i,j)));
00388         }
00389     }
00390 
00391     std::sort(entries.begin(), entries.end(), EntryComparator());
00392 
00393     const size_t max_n_samples = 30;
00394     int32_t n_samples = std::min(max_n_samples, entries.size());
00395 
00396     return std::vector<CRelaxedTree::entry_t>(entries.begin(), entries.begin() + n_samples);
00397 }
00398 
/** Recolor the node's classes given the current binary SVM.
 *
 * For each class, hinge slacks are accumulated assuming it were positive
 * (xi_pos_class) or negative (xi_neg_class); the per-class average minus
 * A/C gives delta_pos / delta_neg.  A class becomes 0 (goes both ways) if
 * both deltas are positive, otherwise it takes the cheaper side.  The
 * result is then adjusted to satisfy the balance bound |sum mu| <= m_B and
 * to guarantee at least one positive and one negative class.
 *
 * @param svm     trained binary SVM of this node
 * @param classes active classes (global class ids)
 * @return coloring mu over @p classes with values in {-1, 0, +1}
 */
SGVector<int32_t> CRelaxedTree::color_label_space(CSVM *svm, SGVector<int32_t> classes)
{
    SGVector<int32_t> mu(classes.vlen);
    CMulticlassLabels *labels = dynamic_cast<CMulticlassLabels *>(m_labels);

    // shifted SVM outputs for every training vector
    SGVector<float64_t> resp = eval_binary_model_K(svm);
    ASSERT(resp.vlen == labels->get_num_labels());

    SGVector<float64_t> xi_pos_class(classes.vlen), xi_neg_class(classes.vlen);
    SGVector<float64_t> delta_pos(classes.vlen), delta_neg(classes.vlen);

    for (int32_t i=0; i < classes.vlen; ++i)
    {
        // find number of instances from this class
        int32_t ni=0;
        for (int32_t j=0; j < labels->get_num_labels(); ++j)
        {
            if (labels->get_int_label(j) == classes[i])
            {
                ni++;
            }
        }

        // hinge slacks if this class were labeled +1 resp. -1
        xi_pos_class[i] = 0;
        xi_neg_class[i] = 0;
        for (int32_t j=0; j < resp.vlen; ++j)
        {
            if (labels->get_int_label(j) == classes[i])
            {
                xi_pos_class[i] += std::max(0.0, 1 - resp[j]);
                xi_neg_class[i] += std::max(0.0, 1 + resp[j]);
            }
        }

        // average slack per side, shifted by A/C
        delta_pos[i] = 1.0/ni * xi_pos_class[i] - float64_t(m_A)/m_svm_C;
        delta_neg[i] = 1.0/ni * xi_neg_class[i] - float64_t(m_A)/m_svm_C;

        if (delta_pos[i] > 0 && delta_neg[i] > 0)
        {
            // both sides costly: send the class down both subtrees
            mu[i] = 0;
        }
        else
        {
            if (delta_pos[i] < delta_neg[i])
                mu[i] = 1;
            else
                mu[i] = -1;
        }

    }

    // enforce balance constraints
    int32_t B_prime = 0;
    for (int32_t i=0; i < mu.vlen; ++i)
        B_prime += mu[i];

    if (B_prime > m_B)
    {
        enforce_balance_constraints_upper(mu, delta_neg, delta_pos, B_prime, xi_neg_class);
    }
    if (B_prime < -m_B)
    {
        // NOTE(review): passes xi_neg_class here as well; the symmetric
        // choice would be xi_pos_class -- confirm against the algorithm.
        enforce_balance_constraints_lower(mu, delta_neg, delta_pos, B_prime, xi_neg_class);
    }

    // ensure at least one positive class exists
    int32_t npos = 0;
    for (index_t i=0; i < mu.vlen; ++i)
    {
        if (mu[i] == 1)
            npos++;
    }

    if (npos == 0)
    {
        // no positive class
        index_t min_idx = SGVector<float64_t>::arg_min(xi_pos_class.vector, 1, xi_pos_class.vlen);
        mu[min_idx] = 1;
    }

    // ensure at least one negative class exists
    int32_t nneg = 0;
    for (index_t i=0; i < mu.vlen; ++i)
    {
        if (mu[i] == -1)
            nneg++;
    }

    if (nneg == 0)
    {
        // no negative class
        index_t min_idx = SGVector<float64_t>::arg_min(xi_neg_class.vector, 1, xi_neg_class.vlen);
        if (mu[min_idx] == 1 && (npos == 0 || npos == 1))
        {
            // avoid overwriting the only positive class: pick the class with
            // the smallest negative slack among the non-positive ones
            float64_t min_val = 0;
            int32_t i, min_i;
            for (i=0; i < xi_neg_class.vlen; ++i)
            {
                if (mu[i] != 1)
                {
                    min_val = xi_neg_class[i];
                    break;
                }
            }
            min_i = i;
            for (i=i+1; i < xi_neg_class.vlen; ++i)
            {
                if (mu[i] != 1 && xi_neg_class[i] < min_val)
                {
                    min_val = xi_neg_class[i];
                    min_i = i;
                }
            }
            mu[min_i] = -1;
        }
        else
        {
            mu[min_idx] = -1;
        }
    }

    return mu;
}
00521 
00522 void CRelaxedTree::enforce_balance_constraints_upper(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, 
00523         SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class)
00524 {
00525     SGVector<index_t> index_zero = mu.find(0);
00526     SGVector<index_t> index_pos = mu.find_if(std::bind1st(std::less<int32_t>(), 0)); 
00527 
00528     int32_t num_zero = index_zero.vlen;
00529     int32_t num_pos  = index_pos.vlen;
00530 
00531     SGVector<index_t> class_index(num_zero+2*num_pos);
00532     std::copy(&index_zero[0], &index_zero[num_zero], &class_index[0]);
00533     std::copy(&index_pos[0], &index_pos[num_pos], &class_index[num_zero]);
00534     std::copy(&index_pos[0], &index_pos[num_pos], &class_index[num_pos+num_zero]);
00535 
00536     SGVector<int32_t> orig_mu(num_zero + 2*num_pos);
00537     orig_mu.zero();
00538     std::fill(&orig_mu[num_zero], &orig_mu[orig_mu.vlen], 1);
00539 
00540     SGVector<int32_t> delta_steps(num_zero+2*num_pos);
00541     std::fill(&delta_steps[0], &delta_steps[delta_steps.vlen], 1);
00542 
00543     SGVector<int32_t> new_mu(num_zero + 2*num_pos);
00544     new_mu.zero();
00545     std::fill(&new_mu[0], &new_mu[num_zero], -1);
00546 
00547     SGVector<float64_t> S_delta(num_zero + 2*num_pos);
00548     S_delta.zero();
00549     for (index_t i=0; i < num_zero; ++i)
00550         S_delta[i] = delta_neg[index_zero[i]];
00551 
00552     for (int32_t i=0; i < num_pos; ++i)
00553     {
00554         float64_t delta_k = delta_neg[index_pos[i]];
00555         float64_t delta_k_0 = -delta_pos[index_pos[i]];
00556 
00557         index_t tmp_index = num_zero + i*2;
00558         if (delta_k_0 <= delta_k)
00559         {
00560             new_mu[tmp_index] = 0;
00561             new_mu[tmp_index+1] = -1;
00562 
00563             S_delta[tmp_index] = delta_k_0;
00564             S_delta[tmp_index+1] = delta_k;
00565 
00566             delta_steps[tmp_index] = 1;
00567             delta_steps[tmp_index+1] = 1;
00568         }
00569         else
00570         {
00571             new_mu[tmp_index] = -1;
00572             new_mu[tmp_index+1] = 0;
00573 
00574             S_delta[tmp_index] = (delta_k_0+delta_k)/2;
00575             S_delta[tmp_index+1] = delta_k_0;
00576 
00577             delta_steps[tmp_index] = 2;
00578             delta_steps[tmp_index+1] = 1;
00579         }
00580     }
00581 
00582     SGVector<index_t> sorted_index = S_delta.sorted_index();
00583     SGVector<float64_t> S_delta_sorted(S_delta.vlen);
00584     for (index_t i=0; i < sorted_index.vlen; ++i)
00585     {
00586         S_delta_sorted[i] = S_delta[sorted_index[i]];
00587         new_mu[i] = new_mu[sorted_index[i]];
00588         orig_mu[i] = orig_mu[sorted_index[i]];
00589         class_index[i] = class_index[sorted_index[i]];
00590         delta_steps[i] = delta_steps[sorted_index[i]];
00591     }
00592 
00593     SGVector<int32_t> valid_flag(S_delta.vlen);
00594     std::fill(&valid_flag[0], &valid_flag[valid_flag.vlen], 1);
00595 
00596     int32_t d=0;
00597     int32_t ctr=0;
00598 
00599     while (true)
00600     {
00601         if (d == B_prime - m_B || d == B_prime - m_B + 1)
00602             break;
00603 
00604         while (!valid_flag[ctr])
00605             ctr++;
00606 
00607         if (delta_steps[ctr] == 1)
00608         {
00609             mu[class_index[ctr]] = new_mu[ctr];
00610             d++;
00611         }
00612         else
00613         {
00614             // this case should happen only when rho >= 1
00615             if (d <= B_prime - m_B - 2)
00616             {
00617                 mu[class_index[ctr]] = new_mu[ctr];
00618                 ASSERT(new_mu[ctr] == -1);
00619                 d += 2;
00620                 for (index_t i=0; i < class_index.vlen; ++i)
00621                 {
00622                     if (class_index[i] == class_index[ctr])
00623                         valid_flag[i] = 0;
00624                 }
00625             }
00626             else
00627             {
00628                 float64_t Delta_k_minus = 2*S_delta_sorted[ctr];
00629 
00630                 // find the next smallest Delta_j or Delta_{j,0}
00631                 float64_t Delta_j_min=0;
00632                 int32_t j=0;
00633                 for (int32_t itr=ctr+1; itr < S_delta_sorted.vlen; ++itr)
00634                 {
00635                     if (valid_flag[itr] == 0)
00636                         continue;
00637 
00638                     if (delta_steps[itr] == 1)
00639                     {
00640                         j=itr;
00641                         Delta_j_min = S_delta_sorted[j];
00642                     }
00643                 }
00644 
00645                 // find the largest Delta_i or Delta_{i,0}
00646                 float64_t Delta_i_max = 0;
00647                 int32_t i=-1;
00648                 for (int32_t itr=ctr-1; itr >= 0; --itr)
00649                 {
00650                     if (delta_steps[itr] == 1 && valid_flag[itr] == 1)
00651                     {
00652                         i = itr;
00653                         Delta_i_max = S_delta_sorted[i];
00654                     }
00655                 }
00656 
00657                 // find the l with the largest Delta_l_minus - Delta_l_0
00658                 float64_t Delta_l_max = std::numeric_limits<float64_t>::min();
00659                 int32_t l=-1;
00660                 for (int32_t itr=ctr-1; itr >= 0; itr--)
00661                 {
00662                     if (delta_steps[itr] == 2)
00663                     {
00664                         float64_t delta_tmp = xi_neg_class[class_index[itr]];
00665                         if (delta_tmp > Delta_l_max)
00666                         {
00667                             l = itr;
00668                             Delta_l_max = delta_tmp;
00669                         }
00670                     }
00671                 }
00672 
00673                 // one-step-min = j
00674                 if (Delta_j_min <= Delta_k_minus - Delta_i_max &&
00675                         Delta_j_min <= Delta_k_minus - Delta_l_max)
00676                 {
00677                     mu[class_index[j]] = new_mu[j];
00678                     d++;
00679                 }
00680                 else
00681                 {
00682                     // one-step-min = Delta_k_minus - Delta_i_max
00683                     if (Delta_k_minus - Delta_i_max <= Delta_j_min &&
00684                             Delta_k_minus - Delta_i_max <= Delta_k_minus - Delta_l_max)
00685                     {
00686                         mu[class_index[ctr]] = -1;
00687                         if (i > 0)
00688                         {
00689                             mu[class_index[i]] = orig_mu[i];
00690                             d++;
00691                         }
00692                         else
00693                         {
00694                             d += 2;
00695                         }
00696                     }
00697                     else
00698                     {
00699                         ASSERT(l > 0);
00700                         mu[class_index[l]] = 0;
00701                         mu[class_index[ctr]] = -1;
00702                         d++;
00703                     }
00704                 }
00705 
00706             }
00707         }
00708     }
00709 }
00710 
00711 void CRelaxedTree::enforce_balance_constraints_lower(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, 
00712         SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class)
00713 {
00714     SGVector<index_t> index_zero = mu.find(0);
00715     SGVector<index_t> index_neg = mu.find_if(std::bind1st(std::greater<int32_t>(), 0)); 
00716 
00717     int32_t num_zero = index_zero.vlen;
00718     int32_t num_neg  = index_neg.vlen;
00719 
00720     SGVector<index_t> class_index(num_zero+2*num_neg);
00721     std::copy(&index_zero[0], &index_zero[num_zero], &class_index[0]);
00722     std::copy(&index_neg[0], &index_neg[num_neg], &class_index[num_zero]);
00723     std::copy(&index_neg[0], &index_neg[num_neg], &class_index[num_neg+num_zero]);
00724 
00725     SGVector<int32_t> orig_mu(num_zero + 2*num_neg);
00726     orig_mu.zero();
00727     std::fill(&orig_mu[num_zero], &orig_mu[orig_mu.vlen], -1);
00728 
00729     SGVector<int32_t> delta_steps(num_zero+2*num_neg);
00730     std::fill(&delta_steps[0], &delta_steps[delta_steps.vlen], 1);
00731 
00732     SGVector<int32_t> new_mu(num_zero + 2*num_neg);
00733     new_mu.zero();
00734     std::fill(&new_mu[0], &new_mu[num_zero], 1);
00735 
00736     SGVector<float64_t> S_delta(num_zero + 2*num_neg);
00737     S_delta.zero();
00738     for (index_t i=0; i < num_zero; ++i)
00739         S_delta[i] = delta_pos[index_zero[i]];
00740 
00741     for (int32_t i=0; i < num_neg; ++i)
00742     {
00743         float64_t delta_k = delta_pos[index_neg[i]];
00744         float64_t delta_k_0 = -delta_neg[index_neg[i]];
00745 
00746         index_t tmp_index = num_zero + i*2;
00747         if (delta_k_0 <= delta_k)
00748         {
00749             new_mu[tmp_index] = 0;
00750             new_mu[tmp_index+1] = 1;
00751 
00752             S_delta[tmp_index] = delta_k_0;
00753             S_delta[tmp_index+1] = delta_k;
00754 
00755             delta_steps[tmp_index] = 1;
00756             delta_steps[tmp_index+1] = 1;
00757         }
00758         else
00759         {
00760             new_mu[tmp_index] = 1;
00761             new_mu[tmp_index+1] = 0;
00762 
00763             S_delta[tmp_index] = (delta_k_0+delta_k)/2;
00764             S_delta[tmp_index+1] = delta_k_0;
00765 
00766             delta_steps[tmp_index] = 2;
00767             delta_steps[tmp_index+1] = 1;
00768         }
00769     }
00770 
00771     SGVector<index_t> sorted_index = S_delta.sorted_index();
00772     SGVector<float64_t> S_delta_sorted(S_delta.vlen);
00773     for (index_t i=0; i < sorted_index.vlen; ++i)
00774     {
00775         S_delta_sorted[i] = S_delta[sorted_index[i]];
00776         new_mu[i] = new_mu[sorted_index[i]];
00777         orig_mu[i] = orig_mu[sorted_index[i]];
00778         class_index[i] = class_index[sorted_index[i]];
00779         delta_steps[i] = delta_steps[sorted_index[i]];
00780     }
00781 
00782     SGVector<int32_t> valid_flag(S_delta.vlen);
00783     std::fill(&valid_flag[0], &valid_flag[valid_flag.vlen], 1);
00784 
00785     int32_t d=0;
00786     int32_t ctr=0;
00787 
00788     while (true)
00789     {
00790         if (d == -m_B - B_prime || d == -m_B - B_prime + 1)
00791             break;
00792 
00793         while (!valid_flag[ctr])
00794             ctr++;
00795 
00796         if (delta_steps[ctr] == 1)
00797         {
00798             mu[class_index[ctr]] = new_mu[ctr];
00799             d++;
00800         }
00801         else
00802         {
00803             // this case should happen only when rho >= 1
00804             if (d >= -m_B - B_prime - 2)
00805             {
00806                 mu[class_index[ctr]] = new_mu[ctr];
00807                 ASSERT(new_mu[ctr] == 1);
00808                 d += 2;
00809 
00810                 for (index_t i=0; i < class_index.vlen; ++i)
00811                 {
00812                     if (class_index[i] == class_index[ctr])
00813                         valid_flag[i] = 0;
00814                 }
00815             }
00816             else
00817             {
00818                 float64_t Delta_k_minus = 2*S_delta_sorted[ctr];
00819 
00820                 // find the next smallest Delta_j or Delta_{j,0}
00821                 float64_t Delta_j_min=0;
00822                 int32_t j=0;
00823                 for (int32_t itr=ctr+1; itr < S_delta_sorted.vlen; ++itr)
00824                 {
00825                     if (valid_flag[itr] == 0)
00826                         continue;
00827 
00828                     if (delta_steps[itr] == 1)
00829                     {
00830                         j=itr;
00831                         Delta_j_min = S_delta_sorted[j];
00832                     }
00833                 }
00834 
00835                 // find the largest Delta_i or Delta_{i,0}
00836                 float64_t Delta_i_max = 0;
00837                 int32_t i=-1;
00838                 for (int32_t itr=ctr-1; itr >= 0; --itr)
00839                 {
00840                     if (delta_steps[itr] == 1 && valid_flag[itr] == 1)
00841                     {
00842                         i = itr;
00843                         Delta_i_max = S_delta_sorted[i];
00844                     }
00845                 }
00846 
00847                 // find the l with the largest Delta_l_minus - Delta_l_0
00848                 float64_t Delta_l_max = std::numeric_limits<float64_t>::min();
00849                 int32_t l=-1;
00850                 for (int32_t itr=ctr-1; itr >= 0; itr--)
00851                 {
00852                     if (delta_steps[itr] == 2)
00853                     {
00854                         float64_t delta_tmp = xi_neg_class[class_index[itr]];
00855                         if (delta_tmp > Delta_l_max)
00856                         {
00857                             l = itr;
00858                             Delta_l_max = delta_tmp;
00859                         }
00860                     }
00861                 }
00862 
00863                 // one-step-min = j
00864                 if (Delta_j_min <= Delta_k_minus - Delta_i_max &&
00865                         Delta_j_min <= Delta_k_minus - Delta_l_max)
00866                 {
00867                     mu[class_index[j]] = new_mu[j];
00868                     d++;
00869                 }
00870                 else
00871                 {
00872                     // one-step-min = Delta_k_minus - Delta_i_max
00873                     if (Delta_k_minus - Delta_i_max <= Delta_j_min &&
00874                             Delta_k_minus - Delta_i_max <= Delta_k_minus - Delta_l_max)
00875                     {
00876                         mu[class_index[ctr]] = -1;
00877                         if (i > 0)
00878                         {
00879                             mu[class_index[i]] = orig_mu[i];
00880                             d++;
00881                         }
00882                         else
00883                         {
00884                             d += 2;
00885                         }
00886                     }
00887                     else
00888                     {
00889                         ASSERT(l > 0);
00890                         mu[class_index[l]] = 0;
00891                         mu[class_index[ctr]] = -1;
00892                         d++;
00893                     }
00894                 }
00895 
00896             }
00897         }
00898     }
00899 }
00900 
00901 SGVector<float64_t> CRelaxedTree::eval_binary_model_K(CSVM *svm)
00902 {
00903     CRegressionLabels *lab = svm->apply_regression(m_feats);
00904     SGVector<float64_t> resp(lab->get_num_labels());
00905     for (int32_t i=0; i < resp.vlen; ++i)
00906         resp[i] = lab->get_label(i) - m_A/m_svm_C;
00907     SG_UNREF(lab);
00908     return resp;
00909 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation