SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
LogDetEstimator.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Soumyajit De
8  */
9 
10 #include <shogun/lib/common.h>
11 #include <shogun/lib/SGVector.h>
12 #include <shogun/lib/SGMatrix.h>
20 
21 namespace shogun
22 {
23 
25  : CSGObject()
26 {
27  init();
28 }
29 
31  COperatorFunction<float64_t>* operator_log,
32  CIndependentComputationEngine* computation_engine)
33  : CSGObject()
34 {
35  init();
36 
37  m_trace_sampler=trace_sampler;
38  SG_REF(m_trace_sampler);
39 
40  m_operator_log=operator_log;
41  SG_REF(m_operator_log);
42 
43  m_computation_engine=computation_engine;
44  SG_REF(m_computation_engine);
45 }
46 
47 void CLogDetEstimator::init()
48 {
49  m_trace_sampler=NULL;
50  m_operator_log=NULL;
51  m_computation_engine=NULL;
52 
53  SG_ADD((CSGObject**)&m_trace_sampler, "trace_sampler",
54  "Trace sampler for the log operator", MS_NOT_AVAILABLE);
55 
56  SG_ADD((CSGObject**)&m_operator_log, "operator_log",
57  "The log operator function", MS_NOT_AVAILABLE);
58 
59  SG_ADD((CSGObject**)&m_computation_engine, "computation_engine",
60  "The computation engine for the jobs", MS_NOT_AVAILABLE);
61 }
62 
64 {
65  SG_UNREF(m_trace_sampler);
66  SG_UNREF(m_operator_log);
67  SG_UNREF(m_computation_engine);
68 }
69 
71 {
72  SG_DEBUG("Entering\n");
73  SG_INFO("Computing %d log-det estimates\n", num_estimates);
74 
75  REQUIRE(m_operator_log, "Operator function is NULL\n");
76  // call the precompute of operator function to compute the prerequisites
77  m_operator_log->precompute();
78 
79  REQUIRE(m_trace_sampler, "Trace sampler is NULL\n");
80  // call the precompute of the sampler
81  m_trace_sampler->precompute();
82 
83  REQUIRE(m_operator_log->get_operator()->get_dimension()\
84  ==m_trace_sampler->get_dimension(),
85  "Mismatch in dimensions of the operator and trace-sampler, %d vs %d!\n",
86  m_operator_log->get_operator()->get_dimension(),
87  m_trace_sampler->get_dimension());
88 
89  // for storing the aggregators that submit_jobs return
90  CDynamicObjectArray* aggregators=new CDynamicObjectArray();
91  index_t num_trace_samples=m_trace_sampler->get_num_samples();
92 
93  for (index_t i=0; i<num_estimates; ++i)
94  {
95  for (index_t j=0; j<num_trace_samples; ++j)
96  {
97  SG_INFO("Computing log-determinant trace sample %d/%d\n", j,
98  num_trace_samples);
99 
100  SG_DEBUG("Creating job for estimate %d, trace sample %d/%d\n", i, j,
101  num_trace_samples);
102  // get the trace sampler vector
103  SGVector<float64_t> s=m_trace_sampler->sample(j);
104  // create jobs with the sample vector and store the aggregator
105  CJobResultAggregator* agg=m_operator_log->submit_jobs(s);
106  aggregators->append_element(agg);
107  SG_UNREF(agg);
108  }
109  }
110 
111  REQUIRE(m_computation_engine, "Computation engine is NULL\n");
112 
113  // wait for all the jobs to be completed
114  SG_INFO("Waiting for jobs to finish\n");
115  m_computation_engine->wait_for_all();
116  SG_INFO("All jobs finished, aggregating results\n");
117 
118  // the samples vector which stores the estimates with averaging
119  SGVector<float64_t> samples(num_estimates);
120  samples.zero();
121 
122  // use the aggregators to find the final result
123  // use the same order as job submission to combine results
124  int32_t num_aggregates=aggregators->get_num_elements();
125  index_t idx_row=0;
126  index_t idx_col=0;
127  for (int32_t i=0; i<num_aggregates; ++i)
128  {
129  // this cast is safe due to above way of building the array
130  CJobResultAggregator* agg=dynamic_cast<CJobResultAggregator*>
131  (aggregators->get_element(i));
132  ASSERT(agg);
133 
134  // call finalize on all the aggregators, cast is safe again
135  agg->finalize();
137  (agg->get_final_result());
138  ASSERT(r);
139 
140  // iterate through indices, group results in the same way as jobs
141  samples[idx_col]+=r->get_result();
142  idx_row++;
143  if (idx_row>=num_trace_samples)
144  {
145  idx_row=0;
146  idx_col++;
147  }
148 
149  SG_UNREF(agg);
150  }
151 
152  // clear all aggregators
153  SG_UNREF(aggregators)
154 
155  SG_INFO("Finished computing %d log-det estimates\n", num_estimates);
156 
157  SG_DEBUG("Leaving\n");
158  return samples;
159 }
160 
162  index_t num_estimates)
163 {
164  SG_DEBUG("Entering...\n")
165 
166  REQUIRE(m_operator_log, "Operator function is NULL\n");
167  // call the precompute of operator function to compute all prerequisites
168  m_operator_log->precompute();
169 
170  REQUIRE(m_trace_sampler, "Trace sampler is NULL\n");
171  // call the precompute of the sampler
172  m_trace_sampler->precompute();
173 
174  // for storing the aggregators that submit_jobs return
175  CDynamicObjectArray aggregators;
176  index_t num_trace_samples=m_trace_sampler->get_num_samples();
177 
178  for (index_t i=0; i<num_estimates; ++i)
179  {
180  for (index_t j=0; j<num_trace_samples; ++j)
181  {
182  // get the trace sampler vector
183  SGVector<float64_t> s=m_trace_sampler->sample(j);
184  // create jobs with the sample vector and store the aggregator
185  CJobResultAggregator* agg=m_operator_log->submit_jobs(s);
186  aggregators.append_element(agg);
187  SG_UNREF(agg);
188  }
189  }
190 
191  REQUIRE(m_computation_engine, "Computation engine is NULL\n");
192  // wait for all the jobs to be completed
193  m_computation_engine->wait_for_all();
194 
195  // the samples matrix which stores the estimates without averaging
196  // dimension: number of trace samples x number of log-det estimates
197  SGMatrix<float64_t> samples(num_trace_samples, num_estimates);
198 
199  // use the aggregators to find the final result
200  int32_t num_aggregates=aggregators.get_num_elements();
201  for (int32_t i=0; i<num_aggregates; ++i)
202  {
203  CJobResultAggregator* agg=dynamic_cast<CJobResultAggregator*>
204  (aggregators.get_element(i));
205  if (!agg)
206  SG_ERROR("Element is not CJobResultAggregator type!\n");
207 
208  // call finalize on all the aggregators
209  agg->finalize();
211  (agg->get_final_result());
212  if (!r)
213  SG_ERROR("Result is not CScalarResult type!\n");
214 
215  // its important that we don't just unref the result here
216  index_t idx_row=i%num_trace_samples;
217  index_t idx_col=i/num_trace_samples;
218  samples(idx_row, idx_col)=r->get_result();
219  SG_UNREF(agg);
220  }
221 
222  // clear all aggregators
223  aggregators.clear_array();
224 
225  SG_DEBUG("Leaving\n")
226  return samples;
227 }
228 
229 }
230 

SHOGUN Machine Learning Toolbox - Documentation