00001 /* 00002 * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights 00003 * embodied in the content of this file are licensed under the BSD 00004 * (revised) open source license. 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * Written (W) 2011 Shashwat Lal Das 00012 * Adaptation of Vowpal Wabbit v5.1. 00013 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society. 00014 */ 00015 00016 #ifndef _VW_REGRESSOR_H__ 00017 #define _VW_REGRESSOR_H__ 00018 00019 #include <shogun/base/SGObject.h> 00020 #include <shogun/lib/DataType.h> 00021 #include <shogun/classifier/vw/VwEnvironment.h> 00022 #include <shogun/loss/LossFunction.h> 00023 00024 namespace shogun 00025 { 00026 00035 class CVwRegressor: public CSGObject 00036 { 00037 public: 00041 CVwRegressor(); 00042 00048 CVwRegressor(CVwEnvironment* env_to_use); 00049 00053 virtual ~CVwRegressor(); 00054 00063 inline float64_t get_loss(float64_t prediction, float64_t label) 00064 { 00065 return loss->loss(prediction, label); 00066 } 00067 00078 inline float64_t get_update(float64_t prediction, float64_t label, 00079 float64_t eta_t, float64_t norm) 00080 { 00081 return loss->get_update(prediction, label, eta_t, norm); 00082 } 00083 00090 virtual void dump_regressor(char* reg_name, bool as_text); 00091 00097 virtual void load_regressor(char* file_name); 00098 00103 virtual const char* get_name() const { return "VwRegressor"; } 00104 00110 virtual void init(CVwEnvironment* env_to_use = NULL); 00111 00112 public: 00114 float32_t** weight_vectors; 00116 CLossFunction* loss; 00117 00118 protected: 00120 CVwEnvironment* env; 00121 }; 00122 00123 } 00124 #endif // _VW_REGRESSOR_H__