00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00021
00022 #ifndef _SSL_H
00023 #define _SSL_H
00024
00025
/**
 * Tuning constants for the solvers declared in this header.
 * NOTE(review): the roles in the trailing comments are read off the macro
 * names (and the options::cgitermax / options::mfnitermax fields); the code
 * that consumes them is not visible here - confirm against the implementation.
 */

/* --- iteration caps ------------------------------------------------- */
#define CGITERMAX           10000   /* presumably default CG iteration cap */
#define SMALL_CGITERMAX     10      /* presumably reduced CG iteration cap */
#define MFNITERMAX          50      /* presumably default MFN iteration cap */

/* --- tolerances ----------------------------------------------------- */
#define EPSILON             1e-6
#define BIG_EPSILON         0.01
#define RELATIVE_STOP_EPS   1e-9

/* --- transductive SVM (TSVM) schedule -------------------------------- */
#define TSVM_ANNEALING_RATE 1.5
#define TSVM_LAMBDA_SMALL   1e-5

/* --- deterministic annealing (DA) schedule --------------------------- */
#define DA_ANNEALING_RATE   1.5
#define DA_INIT_TEMP        10      /* integer literal, as in the original */
#define DA_INNER_ITERMAX    100
#define DA_OUTER_ITERMAX    30
00038
00039 #include <shogun/lib/common.h>
00040 #include <shogun/features/DotFeatures.h>
00041
00042 namespace shogun
00043 {
00045 struct data
00046 {
00048 int32_t m;
00050 int32_t l;
00052 int32_t u;
00054 int32_t n;
00056 int32_t nz;
00057
00059 shogun::CDotFeatures* features;
00061 float64_t *Y;
00063 float64_t *C;
00064 };
00065
00067 struct vector_double
00068 {
00070 int32_t d;
00072 float64_t *vec;
00073 };
00074
00076 struct vector_int
00077 {
00079 int32_t d;
00081 int32_t *vec;
00082 };
00083
/**
 * Algorithm identifiers. The values are spelled out because they are
 * presumably stored as plain integers (options::algo is an int32_t) -
 * NOTE(review): confirm that algo is compared against these.
 */
enum
{
	RLS    = 0,
	SVM    = 1,
	TSVM   = 2,
	DA_SVM = 3
};
00085
00087 struct options
00088 {
00089
00091 int32_t algo;
00093 float64_t lambda;
00095 float64_t lambda_u;
00097 int32_t S;
00099 float64_t R;
00101 float64_t Cp;
00103 float64_t Cn;
00104
00105
00107 float64_t epsilon;
00109 int32_t cgitermax;
00111 int32_t mfnitermax;
00112
00114 float64_t bias;
00115 };
00116
00118 class Delta {
00119 public:
00121 Delta() { delta=0.0; index=0;s=0; }
00122
00124 float64_t delta;
00126 int32_t index;
00128 int32_t s;
00129 };
00130
00131 inline bool operator<(const Delta& a , const Delta& b)
00132 {
00133 return (a.delta < b.delta);
00134 }
00135
00136 void initialize(struct vector_double *A, int32_t k, float64_t a);
00137
00138 void initialize(struct vector_int *A, int32_t k);
00139
00140 void GetLabeledData(struct data *Data_Labeled, const struct data *Data);
00141
00142 float64_t norm_square(const vector_double *A);
00143
00144
00145
00146
/**
 * Top-level training entry point.
 * NOTE(review): presumably dispatches on Options->algo to one of the solvers
 * below and leaves weights in W and per-example outputs in O - confirm.
 */
void ssl_train(struct data *Data, struct options *Options,
		struct vector_double *W, struct vector_double *O);
00152
00153
00154
00155
00156
00157
00158 int32_t CGLS(
00159 const struct data *Data,
00160 const struct options *Options,
00161 const struct vector_int *Subset,
00162 struct vector_double *Weights,
00163 struct vector_double *Outputs);
00164
00165
00166
00167 int32_t L2_SVM_MFN(
00168 const struct data *Data,
00169 struct options *Options,
00170 struct vector_double *Weights,
00171 struct vector_double *Outputs,
00172 int32_t ini);
00173
00174 float64_t line_search(
00175 float64_t *w,
00176 float64_t *w_bar,
00177 float64_t lambda,
00178 float64_t *o,
00179 float64_t *o_bar,
00180 float64_t *Y,
00181 float64_t *C,
00182 int32_t d,
00183 int32_t l);
00184
00185
00186
00187
00188 int32_t TSVM_MFN(
00189 const struct data *Data,
00190 struct options *Options,
00191 struct vector_double *Weights,
00192 struct vector_double *Outputs);
00193
00194 int32_t switch_labels(
00195 float64_t* Y,
00196 float64_t* o,
00197 int32_t* JU,
00198 int32_t u,
00199 int32_t S);
00200
00201
00202 int32_t DA_S3VM(
00203 struct data *Data,
00204 struct options *Options,
00205 struct vector_double *Weights,
00206 struct vector_double *Outputs);
00207
00208 void optimize_p(
00209 const float64_t* g, int32_t u, float64_t T, float64_t r, float64_t*p);
00210
00211 int32_t optimize_w(
00212 const struct data *Data,
00213 const float64_t *p,
00214 struct options *Options,
00215 struct vector_double *Weights,
00216 struct vector_double *Outputs,
00217 int32_t ini);
00218
00219 float64_t transductive_cost(
00220 float64_t normWeights,
00221 float64_t *Y,
00222 float64_t *Outputs,
00223 int32_t m,
00224 float64_t lambda,
00225 float64_t lambda_u);
00226
00227 float64_t entropy(const float64_t *p, int32_t u);
00228
00229
00230 float64_t KL(const float64_t *p, const float64_t *q, int32_t u);
00231 }
00232 #endif // _SSL_H
00233
00234 #endif // DOXYGEN_SHOULD_SKIP_THIS