00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #ifndef MULTITASKLOGISTICREGRESSION_H_ 00011 #define MULTITASKLOGISTICREGRESSION_H_ 00012 00013 #include <shogun/lib/config.h> 00014 #include <shogun/transfer/multitask/MultitaskLinearMachine.h> 00015 #include <shogun/transfer/multitask/TaskRelation.h> 00016 #include <shogun/transfer/multitask/TaskGroup.h> 00017 #include <shogun/transfer/multitask/TaskTree.h> 00018 #include <shogun/transfer/multitask/Task.h> 00019 00020 #include <vector> 00021 #include <set> 00022 00023 using namespace std; 00024 00025 namespace shogun 00026 { 00035 class CMultitaskLogisticRegression : public CMultitaskLinearMachine 00036 { 00037 00038 public: 00040 MACHINE_PROBLEM_TYPE(PT_BINARY) 00041 00042 00043 CMultitaskLogisticRegression(); 00044 00052 CMultitaskLogisticRegression( 00053 float64_t z, CDotFeatures* training_data, 00054 CBinaryLabels* training_labels, CTaskRelation* task_relation); 00055 00057 virtual ~CMultitaskLogisticRegression(); 00058 00060 virtual const char* get_name() const 00061 { 00062 return "MultitaskLogisticRegression"; 00063 } 00064 00066 int32_t get_max_iter() const; 00068 float64_t get_q() const; 00070 int32_t get_regularization() const; 00072 int32_t get_termination() const; 00074 float64_t get_tolerance() const; 00076 float64_t get_z() const; 00077 00079 void set_max_iter(int32_t max_iter); 00081 void set_q(float64_t q); 00083 void set_regularization(int32_t regularization); 00085 void set_termination(int32_t termination); 00087 void set_tolerance(float64_t tolerance); 00089 void set_z(float64_t z); 00090 00092 virtual float64_t apply_one(int32_t i); 00093 00094 protected: 00095 00097 virtual bool train_machine(CFeatures* data=NULL); 00098 00100 virtual bool train_locked_implementation(SGVector<index_t>* tasks); 00101 00102 private: 00103 00105 void register_parameters(); 00106 00108 void initialize_parameters(); 00109 00110 protected: 00111 00113 int32_t m_regularization; 00114 00116 int32_t m_termination; 00117 00119 int32_t m_max_iter; 00120 00122 float64_t m_tolerance; 00123 00125 float64_t m_q; 00126 00128 float64_t m_z; 00129 00130 }; 00131 } 00132 #endif