Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/lib/config.h>
00012
00013 #ifdef USE_SVMLIGHT
00014
00015 #include <shogun/classifier/svm/DomainAdaptationSVM.h>
00016 #include <shogun/io/SGIO.h>
00017 #include <iostream>
00018 #include <vector>
00019
00020 using namespace shogun;
00021
00022 CDomainAdaptationSVM::CDomainAdaptationSVM() : CSVMLight()
00023 {
00024 }
00025
00026 CDomainAdaptationSVM::CDomainAdaptationSVM(float64_t C, CKernel* k, CLabels* lab, CSVM* pre_svm, float64_t B_param) : CSVMLight(C, k, lab)
00027 {
00028 init();
00029 init(pre_svm, B_param);
00030 }
00031
00032 CDomainAdaptationSVM::~CDomainAdaptationSVM()
00033 {
00034 SG_UNREF(presvm);
00035 SG_DEBUG("deleting DomainAdaptationSVM\n");
00036 }
00037
00038
00039 void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param)
00040 {
00041
00042 SG_REF(pre_svm);
00043
00044 this->presvm=pre_svm;
00045 this->B=B_param;
00046 this->train_factor=1.0;
00047
00048
00049 this->presvm->set_bias(0.0);
00050
00051
00052 is_presvm_sane();
00053 }
00054
00055 bool CDomainAdaptationSVM::is_presvm_sane()
00056 {
00057 if (!presvm) {
00058 SG_ERROR("presvm is null");
00059 }
00060
00061 if (presvm->get_num_support_vectors() == 0) {
00062 SG_ERROR("presvm has no support vectors, please train first");
00063 }
00064
00065 if (presvm->get_bias() != 0) {
00066 SG_ERROR("presvm bias not set to zero");
00067 }
00068
00069 if (presvm->get_kernel()->get_kernel_type() != this->get_kernel()->get_kernel_type()) {
00070 SG_ERROR("kernel types do not agree");
00071 }
00072
00073 if (presvm->get_kernel()->get_feature_type() != this->get_kernel()->get_feature_type()) {
00074 SG_ERROR("feature types do not agree");
00075 }
00076
00077 return true;
00078 }
00079
00080
00081 bool CDomainAdaptationSVM::train_machine(CFeatures* data)
00082 {
00083
00084 if (data)
00085 {
00086 if (labels->get_num_labels() != data->get_num_vectors())
00087 SG_ERROR("Number of training vectors does not match number of labels\n");
00088 kernel->init(data, data);
00089 }
00090
00091 int32_t num_training_points = get_labels()->get_num_labels();
00092
00093
00094 float64_t* lin_term = SG_MALLOC(float64_t, num_training_points);
00095
00096
00097 CFeatures* train_data = get_kernel()->get_lhs();
00098
00099
00100 CLabels* parent_svm_out = presvm->apply(train_data);
00101
00102
00103 for (int32_t i=0; i<num_training_points; i++)
00104 {
00105 lin_term[i] = train_factor * B * get_label(i) * parent_svm_out->get_label(i) - 1.0;
00106 }
00107
00108
00109 this->set_linear_term(SGVector<float64_t>(lin_term, num_training_points));
00110
00111 SG_FREE(lin_term);
00112
00113
00114 bool success = CSVMLight::train_machine();
00115
00116 ASSERT(presvm)
00117
00118 return success;
00119
00120 }
00121
00122
00123 CSVM* CDomainAdaptationSVM::get_presvm()
00124 {
00125 SG_REF(presvm);
00126 return presvm;
00127 }
00128
00129
00130 float64_t CDomainAdaptationSVM::get_B()
00131 {
00132 return B;
00133 }
00134
00135
00136 float64_t CDomainAdaptationSVM::get_train_factor()
00137 {
00138 return train_factor;
00139 }
00140
00141
00142 void CDomainAdaptationSVM::set_train_factor(float64_t factor)
00143 {
00144 train_factor = factor;
00145 }
00146
00147
00148 CLabels* CDomainAdaptationSVM::apply(CFeatures* data)
00149 {
00150
00151 ASSERT(presvm->get_bias()==0.0);
00152
00153 int32_t num_examples = data->get_num_vectors();
00154
00155 CLabels* out_current = CSVMLight::apply(data);
00156
00157
00158 CLabels* out_presvm = presvm->apply(data);
00159
00160
00161
00162 for (int32_t i=0; i!=num_examples; i++)
00163 {
00164 float64_t out_combined = out_current->get_label(i) + B*out_presvm->get_label(i);
00165 out_current->set_label(i, out_combined);
00166 }
00167
00168 return out_current;
00169
00170 }
00171
00172 void CDomainAdaptationSVM::init()
00173 {
00174 presvm = NULL;
00175 B = 0;
00176 train_factor = 1.0;
00177
00178 m_parameters->add((CSGObject**) &presvm, "presvm",
00179 "SVM to regularize against.");
00180 m_parameters->add(&B, "B", "regularization parameter B.");
00181 m_parameters->add(&train_factor,
00182 "train_factor", "flag to switch off regularization in training.");
00183 }
00184
00185 #endif //USE_SVMLIGHT