TaskTree.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #include <shogun/transfer/multitask/TaskTree.h>
00011 #include <vector>
00012 
00013 using namespace std;
00014 using namespace shogun;
00015 
00016 struct task_tree_node_t
00017 {
00018     task_tree_node_t(int32_t min, int32_t max, float64_t w)
00019     {
00020         t_min_index = min;
00021         t_max_index = max;
00022         weight = w;
00023     }
00024     int32_t t_min_index, t_max_index;
00025     float64_t weight;
00026 };
00027 
00028 int32_t count_leaf_tasks_recursive(CTask* subtree_root_block)
00029 {
00030     CList* sub_tasks = subtree_root_block->get_subtasks();
00031     int32_t n_sub_tasks = sub_tasks->get_num_elements();
00032     if (n_sub_tasks==0)
00033     {
00034         SG_UNREF(sub_tasks);
00035         return 1;
00036     }
00037     else
00038     {
00039         int32_t sum = 0;
00040         CTask* iterator = (CTask*)sub_tasks->get_first_element();
00041         do
00042         {
00043             sum += count_leaf_tasks_recursive(iterator);
00044             SG_UNREF(iterator);
00045         }
00046         while ((iterator = (CTask*)sub_tasks->get_next_element()) != NULL);
00047 
00048         SG_UNREF(sub_tasks);
00049         return sum;
00050     }
00051 }
00052 
00053 void collect_tree_tasks_recursive(CTask* subtree_root_block, vector<task_tree_node_t>* tree_nodes, int low)
00054 {
00055     int32_t lower = low;
00056     CList* sub_blocks = subtree_root_block->get_subtasks();
00057     if (sub_blocks->get_num_elements()>0)
00058     {
00059         CTask* iterator = (CTask*)sub_blocks->get_first_element();
00060         do
00061         {
00062             if (iterator->get_num_subtasks()>0)
00063             {
00064                 int32_t n_leaves = count_leaf_tasks_recursive(iterator);
00065                 //SG_SDEBUG("Block [%d %d] has %d leaf childs \n",iterator->get_min_index(), iterator->get_max_index(), n_leaves);
00066                 tree_nodes->push_back(task_tree_node_t(lower,lower+n_leaves-1,iterator->get_weight()));
00067                 collect_tree_tasks_recursive(iterator, tree_nodes, lower);
00068                 lower = lower + n_leaves;
00069             }
00070             else
00071                 lower++;
00072             SG_UNREF(iterator);
00073         }
00074         while ((iterator = (CTask*)sub_blocks->get_next_element()) != NULL);
00075     }
00076     SG_UNREF(sub_blocks);
00077 }
00078 
00079 void collect_leaf_tasks_recursive(CTask* subtree_root_block, CList* list)
00080 {
00081     CList* sub_blocks = subtree_root_block->get_subtasks();
00082     if (sub_blocks->get_num_elements() == 0)
00083     {
00084         list->append_element(subtree_root_block);
00085     }
00086     else
00087     {
00088         CTask* iterator = (CTask*)sub_blocks->get_first_element();
00089         do
00090         {
00091             collect_leaf_tasks_recursive(iterator, list);
00092             SG_UNREF(iterator);
00093         } 
00094         while ((iterator = (CTask*)sub_blocks->get_next_element()) != NULL);
00095     }
00096     SG_UNREF(sub_blocks);
00097 }
00098 
00099 int32_t count_leaft_tasks_recursive(CTask* subtree_root_block)
00100 {
00101     CList* sub_blocks = subtree_root_block->get_subtasks();
00102     int32_t r = 0;
00103     if (sub_blocks->get_num_elements() == 0)
00104     {
00105         return 1;
00106     }
00107     else
00108     {
00109         CTask* iterator = (CTask*)sub_blocks->get_first_element();
00110         do
00111         {
00112             r += count_leaf_tasks_recursive(iterator);
00113             SG_UNREF(iterator);
00114         } 
00115         while ((iterator = (CTask*)sub_blocks->get_next_element()) != NULL);
00116     }
00117     SG_UNREF(sub_blocks);
00118     return r;
00119 }
00120 
00121 CTaskTree::CTaskTree() : CTaskRelation(),
00122     m_root_task(NULL)
00123 {
00124 
00125 }
00126 
00127 CTaskTree::CTaskTree(CTask* root_task) : CTaskRelation(),
00128     m_root_task(NULL)
00129 {
00130     set_root_task(root_task);
00131 }
00132 
00133 CTaskTree::~CTaskTree()
00134 {
00135     SG_UNREF(m_root_task);
00136 }
00137 
00138 SGVector<index_t>* CTaskTree::get_tasks_indices() const
00139 {
00140     CList* blocks = new CList(true);
00141     collect_leaf_tasks_recursive(m_root_task, blocks);
00142     SG_DEBUG("Collected %d leaf blocks\n", blocks->get_num_elements());
00143     //check_blocks_list(blocks);
00144 
00145     //SGVector<index_t> ind(blocks->get_num_elements()+1);
00146 
00147     int t_i = 0;
00148     //ind[0] = 0;
00149     //
00150     SGVector<index_t>* tasks_indices = SG_MALLOC(SGVector<index_t>, blocks->get_num_elements());
00151     CTask* iterator = (CTask*)blocks->get_first_element();
00152     do
00153     {
00154         new (&tasks_indices[t_i]) SGVector<index_t>();
00155         tasks_indices[t_i] = iterator->get_indices();
00156         //REQUIRE(iterator->is_contiguous(),"Task is not contiguous");
00157         //ind[t_i+1] = iterator->get_indices()[iterator->get_indices().vlen-1] + 1;
00158         //SG_DEBUG("Block = [%d,%d]\n", iterator->get_min_index(), iterator->get_max_index());
00159         SG_UNREF(iterator);
00160         t_i++;
00161     } 
00162     while ((iterator = (CTask*)blocks->get_next_element()) != NULL);
00163 
00164     SG_UNREF(blocks);
00165 
00166     return tasks_indices;
00167 }
00168 
00169 int32_t CTaskTree::get_num_tasks() const
00170 {
00171     return count_leaf_tasks_recursive(m_root_task);
00172 }
00173 
00174 SGVector<float64_t> CTaskTree::get_SLEP_ind_t()
00175 {
00176     int n_blocks = get_num_tasks() - 1;
00177     SG_DEBUG("Number of blocks = %d \n", n_blocks);
00178 
00179     vector<task_tree_node_t> tree_nodes = vector<task_tree_node_t>();
00180     
00181     collect_tree_tasks_recursive(m_root_task, &tree_nodes,1);
00182 
00183     SGVector<float64_t> ind_t(3+3*tree_nodes.size());
00184     // supernode
00185     ind_t[0] = -1;
00186     ind_t[1] = -1;
00187     ind_t[2] = 1.0;
00188 
00189     for (int32_t i=0; i<(int32_t)tree_nodes.size(); i++)
00190     {
00191         ind_t[3+i*3] = tree_nodes[i].t_min_index;
00192         ind_t[3+i*3+1] = tree_nodes[i].t_max_index;
00193         ind_t[3+i*3+2] = tree_nodes[i].weight;
00194     }
00195 
00196     return ind_t;
00197 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation