SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
DomainAdaptationMulticlassLibLinear.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Sergey Lisitsyn
8  * Copyright (C) 2012 Sergey Lisitsyn
9  */
10 
11 #include <shogun/lib/config.h>
12 #ifdef HAVE_LAPACK
15 
16 using namespace shogun;
17 
20 {
21  init_defaults();
22 }
23 
25  float64_t target_C, CDotFeatures* target_features, CLabels* target_labels,
26  CLinearMulticlassMachine* source_machine) :
27  CMulticlassLibLinear(target_C,target_features,target_labels)
28 {
29  init_defaults();
30 
31  set_source_machine(source_machine);
32 }
33 
// Puts the domain-adaptation members into their default state and then
// registers them as parameters. Both constructors visible in this file call
// this; the four-argument one calls set_source_machine() afterwards.
// Takes no arguments and returns nothing.
34 void CDomainAdaptationMulticlassLibLinear::init_defaults()
35 {
36  m_train_factor = 1.0; // default scale applied to w when building the regularizer matrix
37  m_source_bias = 0.5; // default: target and source outputs are weighted equally (1-0.5 vs 0.5)
38  m_source_machine = NULL; // no source-domain machine attached yet
39 
40  register_parameters(); // runs after the members above have been given their values
41 }
42 
44 {
45  return m_source_bias;
46 }
47 
49 {
50  m_source_bias = source_bias;
51 }
52 
54 {
55  return m_train_factor;
56 }
57 
59 {
60  m_train_factor = train_factor;
61 }
62 
64 {
66  return m_source_machine;
67 }
68 
70  CLinearMulticlassMachine* source_machine)
71 {
72  SG_REF(source_machine);
74  m_source_machine = source_machine;
75 }
76 
// Registers the three member fields of this class through the SG_ADD macro,
// each with a name string and a short description:
//   m_source_machine -> "source_machine"
//   m_train_factor   -> "train_factor"
//   m_source_bias    -> "source_bias"
// The last two are passed MS_AVAILABLE (presumably "available for model
// selection" -- confirm against SGObject.h).
// NOTE(review): the final argument and closing parenthesis of the first
// SG_ADD call (original line 80) are absent from this listing; the listing
// jumps from line 79 to line 81. Check the upstream file before editing.
77 void CDomainAdaptationMulticlassLibLinear::register_parameters()
78 {
79  SG_ADD((CSGObject**)&m_source_machine, "source_machine", "source domain machine",
81  SG_ADD(&m_train_factor, "train_factor", "factor of target domain regularization",
82  MS_AVAILABLE);
83  SG_ADD(&m_source_bias, "source_bias", "bias to source domain",
84  MS_AVAILABLE);
85 }
86 
88 {
89 }
90 
92 {
93  ASSERT(get_use_bias()==false)
94  int32_t n_classes = ((CMulticlassLabels*)m_source_machine->get_labels())->get_num_classes();
95  int32_t n_features = ((CDotFeatures*)m_source_machine->get_features())->get_dim_feature_space();
96  SGMatrix<float64_t> w0(n_classes,n_features);
97 
98  for (int32_t i=0; i<n_classes; i++)
99  {
101  for (int32_t j=0; j<n_features; j++)
102  w0(j,i) = m_train_factor*w[j];
103  }
104 
105  return w0;
106 }
107 
109 {
112  int32_t n_target_outputs = target_outputs->get_num_labels();
113  ASSERT(n_target_outputs==source_outputs->get_num_labels())
114  SGVector<float64_t> result(n_target_outputs);
115  for (int32_t j=0; j<result.vlen; j++)
116  result[j] = (1-m_source_bias)*target_outputs->get_value(j) + m_source_bias*source_outputs->get_value(j);
117 
118  SG_UNREF(target_outputs);
119  SG_UNREF(source_outputs);
120 
121  return new CBinaryLabels(result);
122 }
123 #endif /* HAVE_LAPACK */
virtual SGMatrix< float64_t > obtain_regularizer_matrix() const
virtual float64_t get_value(int32_t idx)
Definition: Labels.cpp:59
CMachine * get_machine(int32_t num) const
virtual int32_t get_num_labels() const
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
Features that support dot products among other operations.
Definition: DotFeatures.h:44
#define SG_REF(x)
Definition: SGObject.h:51
Multiclass Labels for multi-class classification.
generic linear multiclass machine
#define ASSERT(x)
Definition: SGIO.h:201
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:112
void set_source_machine(CLinearMulticlassMachine *source_machine)
double float64_t
Definition: common.h:50
multiclass LibLinear wrapper. Uses Crammer-Singer formulation and gradient descent optimization algor...
virtual CBinaryLabels * get_submachine_outputs(int32_t i)
virtual CLabels * get_labels()
Definition: Machine.cpp:76
Class LinearMachine is a generic interface for all kinds of linear machines like classifiers.
Definition: LinearMachine.h:63
#define SG_UNREF(x)
Definition: SGObject.h:52
all classes and functions are contained in the shogun namespace
Definition: class_list.h:18
Binary Labels for binary classification.
Definition: BinaryLabels.h:37
#define SG_ADD(...)
Definition: SGObject.h:81

SHOGUN 机器学习工具包 - 项目文档