MultitaskROCEvaluation.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #include <shogun/transfer/multitask/MultitaskROCEvaluation.h>
00011 #include <shogun/mathematics/Math.h>
00012 
00013 #include <set>
00014 #include <vector>
00015 
00016 using namespace std;
00017 using namespace shogun;
00018 
00019 void CMultitaskROCEvaluation::set_indices(SGVector<index_t> indices)
00020 {
00021     indices.display_vector("indices");
00022     ASSERT(m_task_relation);
00023 
00024     set<index_t> indices_set;
00025     for (int32_t i=0; i<indices.vlen; i++)
00026         indices_set.insert(indices[i]);
00027 
00028     if (m_num_tasks>0)
00029     {
00030         for (int32_t t=0; t<m_num_tasks; t++)
00031             m_tasks_indices[t].~SGVector<index_t>();
00032         SG_FREE(m_tasks_indices);
00033     }
00034     m_num_tasks = m_task_relation->get_num_tasks();
00035     m_tasks_indices = SG_MALLOC(SGVector<index_t>, m_num_tasks);
00036 
00037     SGVector<index_t>* tasks_indices = m_task_relation->get_tasks_indices();
00038     for (int32_t t=0; t<m_num_tasks; t++)
00039     {
00040         new (&m_tasks_indices[t]) SGVector<index_t>();
00041         vector<index_t> task_indices_cut;
00042         SGVector<index_t> task_indices = tasks_indices[t];
00043         //task_indices.display_vector("task indices");
00044         for (int32_t i=0; i<task_indices.vlen; i++)
00045         {
00046             if (indices_set.count(task_indices[i]))
00047             {
00048                 //SG_SPRINT("%d is in %d task\n",task_indices[i],t);
00049                 task_indices_cut.push_back(task_indices[i]);
00050             }
00051         }
00052 
00053         SGVector<index_t> cutted(task_indices_cut.size());
00054         for (int32_t i=0; i<cutted.vlen; i++)
00055             cutted[i] = task_indices_cut[i];
00056         //cutted.display_vector("cutted");
00057         m_tasks_indices[t] = cutted;
00058         tasks_indices[t].~SGVector<index_t>();
00059     }
00060     SG_FREE(tasks_indices);
00061 }
00062 
00063 float64_t CMultitaskROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
00064 {
00065     //SG_SPRINT("Evaluate\n");
00066     predicted->remove_all_subsets();
00067     ground_truth->remove_all_subsets();
00068     float64_t result = 0.0;
00069     for (int32_t t=0; t<m_num_tasks; t++)
00070     {
00071         //SG_SPRINT("%d task", t);
00072         //m_tasks_indices[t].display_vector();
00073         predicted->add_subset(m_tasks_indices[t]);
00074         ground_truth->add_subset(m_tasks_indices[t]);
00075         result += evaluate_roc(predicted,ground_truth)/m_tasks_indices[t].vlen;
00076         predicted->remove_subset();
00077         ground_truth->remove_subset();
00078     }
00079     return result;
00080 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation