55 #define STDEXT_NAMESPACE __gnu_cxx
56 #define std_hash_map std::map
57 #define std_hash_set std::set
67 #ifndef DOXYGEN_SHOULD_SKIP_THIS
68 struct larank_kcache_s;
69 typedef struct larank_kcache_s larank_kcache_t;
70 struct larank_kcache_s
73 larank_kcache_t *prevbuddy;
74 larank_kcache_t *nextbuddy;
98 LaRankOutput () : beta(NULL), g(NULL), kernel(NULL), l(0)
101 virtual ~LaRankOutput ()
107 void initialize (CKernel* kfunc, int64_t cache);
116 float64_t computeGradient (int32_t xi_id, int32_t yi, int32_t ythis);
124 void set_kernel_buddy (larank_kcache_t * bud);
132 inline larank_kcache_t *getKernel ()
const
137 inline int32_t get_l ()
const
161 bool isSupportVector (int32_t x_id)
const;
171 larank_kcache_t *kernel;
181 LaRankPattern (int32_t x_index, int32_t label)
182 : x_id (x_index), y (label) {}
207 ~LaRankPatterns () {}
209 void insert (
const LaRankPattern & pattern)
211 if (!isPattern (pattern.x_id))
215 std_hash_set < uint32_t >::iterator it = freeidx.begin ();
216 patterns[*it] = pattern;
217 x_id2rank[pattern.x_id] = *it;
222 patterns.push_back (pattern);
223 x_id2rank[pattern.x_id] = patterns.size () - 1;
228 int32_t rank = getPatternRank (pattern.x_id);
229 patterns[rank] = pattern;
233 void remove (uint32_t i)
235 x_id2rank[patterns[i].x_id] = 0;
236 patterns[i].clear ();
242 return patterns.size () == freeidx.size ();
245 uint32_t size ()
const
247 return patterns.size () - freeidx.size ();
250 LaRankPattern & sample ()
255 uint32_t r =
CMath::random(uint32_t(0), uint32_t(patterns.size ()-1));
256 if (patterns[r].exists ())
262 uint32_t getPatternRank (int32_t x_id)
264 return x_id2rank[x_id];
267 bool isPattern (int32_t x_id)
269 return x_id2rank[x_id] != 0;
272 LaRankPattern & getPattern (int32_t x_id)
274 uint32_t rank = x_id2rank[x_id];
275 return patterns[rank];
278 uint32_t maxcount ()
const
280 return patterns.size ();
283 LaRankPattern & operator [] (uint32_t i)
288 const LaRankPattern & operator [] (uint32_t i)
const
294 std_hash_set < uint32_t >freeidx;
295 std::vector < LaRankPattern > patterns;
296 std_hash_map < int32_t, uint32_t >x_id2rank;
300 #endif // DOXYGEN_SHOULD_SKIP_THIS
341 virtual int32_t
add (int32_t x_id, int32_t yi);
347 virtual int32_t
predict (int32_t x_id);
379 virtual const char*
get_name()
const {
return "LaRank"; }
417 typedef std_hash_map < int32_t, LaRankOutput > outputhash_t;
420 outputhash_t outputs;
422 LaRankOutput *getOutput (int32_t index);
425 LaRankPatterns patterns;
428 int32_t nb_seen_examples;
444 struct outputgradient_t
446 outputgradient_t (int32_t result_output,
float64_t result_gradient)
447 : output (result_output), gradient (result_gradient) {}
449 : output (0), gradient (0) {}
454 bool operator < (
const outputgradient_t & og)
const
456 return gradient > og.gradient;
468 struct process_return_t
470 process_return_t (
float64_t dual, int32_t yprediction)
471 : dual_increase (dual), ypred (yprediction) {}
472 process_return_t () {}
478 process_return_t process (
const LaRankPattern & pattern, process_type ptype);
497 return classes.size ();
virtual const char * get_name() const
bool batch_mode
Whether to use online learning or batch training.
bool operator<(const BaseTag &first, const BaseTag &second)
The class Labels models labels, i.e. class assignments of objects.
bool train_machine(CFeatures *data)
void set_batch_mode(bool enable)
int32_t max_iteration
Max number of iterations before training is stopped.
void(* update)(float *foo, float bar)
virtual int32_t add(int32_t x_id, int32_t yi)
uint32_t class_count() const
class count
virtual EMachineType get_classifier_type()
virtual float64_t computeGap()
virtual uint32_t getNumOutputs() const
All classes and functions are contained in the shogun namespace.
The class Features is the base class of all feature objects.
virtual int32_t predict(int32_t x_id)
The LaRank multiclass SVM machine. This implementation uses the LaRank algorithm from Bordes, Antoine, et al., 2007, "Solving multiclass support vector machines with LaRank."
void set_tau(float64_t t)
std_hash_set< int32_t > classes
classes
int32_t step
progress output
int32_t get_max_iteration()
void set_max_iteration(int32_t max_iter)