Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/lib/config.h>
00012
00013 #ifdef HAVE_LAPACK
00014
00015 #include <shogun/classifier/svm/DomainAdaptationSVMLinear.h>
00016 #include <shogun/io/SGIO.h>
00017 #include <shogun/base/Parameter.h>
00018 #include <iostream>
00019 #include <vector>
00020
00021 using namespace shogun;
00022
00023
// Default constructor: no prior model (presvm = NULL) and B = 0.0,
// which degenerates to plain LibLinear with the L2R_L1LOSS_SVC_DUAL solver.
CDomainAdaptationSVMLinear::CDomainAdaptationSVMLinear() : CLibLinear(L2R_L1LOSS_SVC_DUAL)
{
	init(NULL, 0.0);
}
00028
00029
// Full constructor.
// @param C        LibLinear regularization constant
// @param f        training features (forwarded to CLibLinear)
// @param lab      training labels (forwarded to CLibLinear)
// @param pre_svm  previously trained linear machine to regularize against (may be NULL)
// @param B_param  trade-off B towards the prior model's predictions
CDomainAdaptationSVMLinear::CDomainAdaptationSVMLinear(float64_t C, CDotFeatures* f, CLabels* lab, CLinearMachine* pre_svm, float64_t B_param) : CLibLinear(C, f, lab)
{
	init(pre_svm, B_param);

}
00035
00036
CDomainAdaptationSVMLinear::~CDomainAdaptationSVMLinear()
{
	// release the reference taken on the prior SVM in init()
	SG_UNREF(presvm);
	SG_DEBUG("deleting DomainAdaptationSVMLinear\n");
}
00043
00044
00045 void CDomainAdaptationSVMLinear::init(CLinearMachine* pre_svm, float64_t B_param)
00046 {
00047
00048 if (pre_svm)
00049 {
00050
00051 SG_REF(pre_svm);
00052
00053
00054 pre_svm->set_bias(0.0);
00055 }
00056
00057 this->presvm = pre_svm;
00058 this->B = B_param;
00059 this->train_factor = 1.0;
00060
00061 set_liblinear_solver_type(L2R_L1LOSS_SVC_DUAL);
00062
00063
00064 is_presvm_sane();
00065
00066
00067 m_parameters->add((CSGObject**) &presvm, "presvm", "SVM to regularize against");
00068 m_parameters->add(&B, "B", "Regularization strenth B.");
00069 m_parameters->add(&train_factor, "train_factor", "train_factor");
00070
00071 }
00072
00073
00074 bool CDomainAdaptationSVMLinear::is_presvm_sane()
00075 {
00076
00077 if (!presvm) {
00078
00079 SG_WARNING("presvm is null");
00080
00081 } else {
00082
00083 if (presvm->get_bias() != 0) {
00084 SG_ERROR("presvm bias not set to zero");
00085 }
00086
00087 if (presvm->get_features()->get_feature_type() != this->get_features()->get_feature_type()) {
00088 SG_ERROR("feature types do not agree");
00089 }
00090 }
00091
00092 return true;
00093
00094 }
00095
00096
00097 bool CDomainAdaptationSVMLinear::train_machine(CDotFeatures* train_data)
00098 {
00099
00100 CDotFeatures* tmp_data;
00101
00102 if (train_data)
00103 {
00104 if (labels->get_num_labels() != train_data->get_num_vectors())
00105 SG_ERROR("Number of training vectors does not match number of labels\n");
00106 tmp_data = train_data;
00107
00108 } else {
00109
00110 tmp_data = features;
00111 }
00112
00113 int32_t num_training_points = get_labels()->get_num_labels();
00114
00115 std::vector<float64_t> lin_term = std::vector<float64_t>(num_training_points);
00116
00117 if (presvm)
00118 {
00119 ASSERT(presvm->get_bias() == 0.0);
00120
00121
00122 CLabels* parent_svm_out = presvm->apply(tmp_data);
00123
00124 SG_DEBUG("pre-computing linear term from presvm\n");
00125
00126
00127 for (int32_t i=0; i!=num_training_points; i++)
00128 {
00129 lin_term[i] = train_factor * B * get_label(i) * parent_svm_out->get_label(i) - 1.0;
00130 }
00131
00132
00133 this->set_linear_term(
00134 SGVector<float64_t>(&lin_term[0], lin_term.size()));
00135
00136 }
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159 bool success = false;
00160
00161
00162 if (train_data)
00163 {
00164 success = CLibLinear::train_machine(train_data);
00165 } else {
00166 success = CLibLinear::train_machine();
00167 }
00168
00169
00170
00171 return success;
00172
00173 }
00174
00175
00176 CLinearMachine* CDomainAdaptationSVMLinear::get_presvm()
00177 {
00178 return presvm;
00179 }
00180
00181
00182 float64_t CDomainAdaptationSVMLinear::get_B()
00183 {
00184 return B;
00185 }
00186
00187
00188 float64_t CDomainAdaptationSVMLinear::get_train_factor()
00189 {
00190 return train_factor;
00191 }
00192
00193
00194 void CDomainAdaptationSVMLinear::set_train_factor(float64_t factor)
00195 {
00196 train_factor = factor;
00197 }
00198
00199
00200 CLabels* CDomainAdaptationSVMLinear::apply(CDotFeatures* data)
00201 {
00202
00203 ASSERT(presvm->get_bias()==0.0);
00204
00205 int32_t num_examples = data->get_num_vectors();
00206
00207 CLabels* out_current = CLibLinear::apply(data);
00208
00209 if (presvm)
00210 {
00211
00212
00213 CLabels* out_presvm = presvm->apply(data);
00214
00215
00216
00217 for (int32_t i=0; i!=num_examples; i++)
00218 {
00219 float64_t out_combined = out_current->get_label(i) + B*out_presvm->get_label(i);
00220 out_current->set_label(i, out_combined);
00221 }
00222
00223 }
00224
00225
00226 return out_current;
00227
00228 }
00229
00230 #endif //HAVE_LAPACK
00231