Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/lib/config.h>
00012
00013 #ifdef USE_SVMLIGHT
00014
00015 #include <shogun/transfer/domain_adaptation/DomainAdaptationSVM.h>
00016 #include <shogun/io/SGIO.h>
00017 #include <shogun/labels/Labels.h>
00018 #include <shogun/labels/BinaryLabels.h>
00019 #include <shogun/labels/RegressionLabels.h>
00020 #include <iostream>
00021 #include <vector>
00022
00023 using namespace shogun;
00024
00025 CDomainAdaptationSVM::CDomainAdaptationSVM() : CSVMLight()
00026 {
00027 }
00028
00029 CDomainAdaptationSVM::CDomainAdaptationSVM(float64_t C, CKernel* k, CLabels* lab, CSVM* pre_svm, float64_t B_param) : CSVMLight(C, k, lab)
00030 {
00031 init();
00032 init(pre_svm, B_param);
00033 }
00034
00035 CDomainAdaptationSVM::~CDomainAdaptationSVM()
00036 {
00037 SG_UNREF(presvm);
00038 SG_DEBUG("deleting DomainAdaptationSVM\n");
00039 }
00040
00041
00042 void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param)
00043 {
00044
00045 SG_REF(pre_svm);
00046
00047 this->presvm=pre_svm;
00048 this->B=B_param;
00049 this->train_factor=1.0;
00050
00051
00052 this->presvm->set_bias(0.0);
00053
00054
00055 is_presvm_sane();
00056 }
00057
00058 bool CDomainAdaptationSVM::is_presvm_sane()
00059 {
00060 if (!presvm) {
00061 SG_ERROR("presvm is null");
00062 }
00063
00064 if (presvm->get_num_support_vectors() == 0) {
00065 SG_ERROR("presvm has no support vectors, please train first");
00066 }
00067
00068 if (presvm->get_bias() != 0) {
00069 SG_ERROR("presvm bias not set to zero");
00070 }
00071
00072 if (presvm->get_kernel()->get_kernel_type() != this->get_kernel()->get_kernel_type()) {
00073 SG_ERROR("kernel types do not agree");
00074 }
00075
00076 if (presvm->get_kernel()->get_feature_type() != this->get_kernel()->get_feature_type()) {
00077 SG_ERROR("feature types do not agree");
00078 }
00079
00080 return true;
00081 }
00082
00083
00084 bool CDomainAdaptationSVM::train_machine(CFeatures* data)
00085 {
00086
00087 if (data)
00088 {
00089 if (m_labels->get_num_labels() != data->get_num_vectors())
00090 SG_ERROR("Number of training vectors does not match number of labels\n");
00091 kernel->init(data, data);
00092 }
00093
00094 if (m_labels->get_label_type() != LT_BINARY)
00095 SG_ERROR("DomainAdaptationSVM requires binary labels\n");
00096
00097 int32_t num_training_points = get_labels()->get_num_labels();
00098 CBinaryLabels* labels = (CBinaryLabels*) get_labels();
00099
00100 float64_t* lin_term = SG_MALLOC(float64_t, num_training_points);
00101
00102
00103 CFeatures* train_data = get_kernel()->get_lhs();
00104
00105
00106 CBinaryLabels* parent_svm_out = presvm->apply_binary(train_data);
00107
00108
00109 for (int32_t i=0; i<num_training_points; i++)
00110 {
00111 lin_term[i] = train_factor * B * labels->get_label(i) * parent_svm_out->get_label(i) - 1.0;
00112 }
00113
00114
00115 this->set_linear_term(SGVector<float64_t>(lin_term, num_training_points));
00116
00117
00118 bool success = CSVMLight::train_machine();
00119 SG_UNREF(labels);
00120
00121 ASSERT(presvm)
00122
00123 return success;
00124
00125 }
00126
00127
00128 CSVM* CDomainAdaptationSVM::get_presvm()
00129 {
00130 SG_REF(presvm);
00131 return presvm;
00132 }
00133
00134
00135 float64_t CDomainAdaptationSVM::get_B()
00136 {
00137 return B;
00138 }
00139
00140
00141 float64_t CDomainAdaptationSVM::get_train_factor()
00142 {
00143 return train_factor;
00144 }
00145
00146
00147 void CDomainAdaptationSVM::set_train_factor(float64_t factor)
00148 {
00149 train_factor = factor;
00150 }
00151
00152
00153 CBinaryLabels* CDomainAdaptationSVM::apply_binary(CFeatures* data)
00154 {
00155 ASSERT(data);
00156 ASSERT(presvm->get_bias()==0.0);
00157
00158 int32_t num_examples = data->get_num_vectors();
00159
00160 CBinaryLabels* out_current = CSVMLight::apply_binary(data);
00161
00162
00163 CBinaryLabels* out_presvm = presvm->apply_binary(data);
00164
00165
00166 SGVector<float64_t> out_combined(num_examples);
00167 for (int32_t i=0; i<num_examples; i++)
00168 {
00169 out_combined[i] = out_current->get_value(i) + B*out_presvm->get_value(i);
00170 }
00171 SG_UNREF(out_current);
00172 SG_UNREF(out_presvm);
00173
00174 return new CBinaryLabels(out_combined);
00175
00176 }
00177
00178 void CDomainAdaptationSVM::init()
00179 {
00180 presvm = NULL;
00181 B = 0;
00182 train_factor = 1.0;
00183
00184 m_parameters->add((CSGObject**) &presvm, "presvm",
00185 "SVM to regularize against.");
00186 m_parameters->add(&B, "B", "regularization parameter B.");
00187 m_parameters->add(&train_factor,
00188 "train_factor", "flag to switch off regularization in training.");
00189 }
00190
00191 #endif //USE_SVMLIGHT