LinearClassifier.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _LINEARCLASSIFIER_H__
00012 #define _LINEARCLASSIFIER_H__
00013 
00014 #include "lib/common.h"
00015 #include "features/Labels.h"
00016 #include "features/DotFeatures.h"
00017 #include "classifier/Classifier.h"
00018 
00019 #include <stdio.h>
00020 
00021 namespace shogun
00022 {
00023     class CDotFeatures;
00024     class CLabels;
00025 
00060 class CLinearClassifier : public CClassifier
00061 {
00062     public:
00064         CLinearClassifier();
00065         virtual ~CLinearClassifier();
00066 
00068         virtual inline float64_t classify_example(int32_t vec_idx)
00069         {
00070             return features->dense_dot(vec_idx, w, w_dim) + bias;
00071         }
00072 
00078         inline void get_w(float64_t*& dst_w, int32_t& dst_dims)
00079         {
00080             ASSERT(w && w_dim>0);
00081             dst_w=w;
00082             dst_dims=w_dim;
00083         }
00084 
00090         inline void get_w(float64_t** dst_w, int32_t* dst_dims)
00091         {
00092             ASSERT(dst_w && dst_dims);
00093             ASSERT(w && w_dim>0);
00094             *dst_dims=w_dim;
00095             *dst_w=(float64_t*) malloc(sizeof(float64_t)*(*dst_dims));
00096             ASSERT(*dst_w);
00097             memcpy(*dst_w, w, sizeof(float64_t) * (*dst_dims));
00098         }
00099 
00105         inline void set_w(float64_t* src_w, int32_t src_w_dim)
00106         {
00107             delete[] w;
00108             w=new float64_t[src_w_dim];
00109             memcpy(w, src_w, size_t(src_w_dim)*sizeof(float64_t));
00110             w_dim=src_w_dim;
00111         }
00112 
00117         inline void set_bias(float64_t b)
00118         {
00119             bias=b;
00120         }
00121 
00126         inline float64_t get_bias()
00127         {
00128             return bias;
00129         }
00130 
00136         virtual bool load(FILE* srcfile);
00137 
00143         virtual bool save(FILE* dstfile);
00144 
00149         virtual inline void set_features(CDotFeatures* feat)
00150         {
00151             SG_UNREF(features);
00152             SG_REF(feat);
00153             features=feat;
00154         }
00155 
00160         virtual CLabels* classify();
00161 
00167         virtual CLabels* classify(CFeatures* data);
00168 
00173         virtual CDotFeatures* get_features() { SG_REF(features); return features; }
00174 
00180         virtual const char* get_name(void) const {
00181             return "LinearClassifier"; }
00182 
00183     protected:
00185         int32_t w_dim;
00187         float64_t* w;
00189         float64_t bias;
00191         CDotFeatures* features;
00192 };
00193 }
00194 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation