00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef __SGMATRIX_H__
00013 #define __SGMATRIX_H__
00014
00015 #include <shogun/lib/config.h>
00016 #include <shogun/lib/DataType.h>
00017 #include <shogun/lib/SGReferencedData.h>
00018
00019 namespace shogun
00020 {
00021 template<class T> class SGVector;
00022 template<class T> class SGMatrixList;
00024 template<class T> class SGMatrix : public SGReferencedData
00025 {
00026 public:
00028 SGMatrix() : SGReferencedData()
00029 {
00030 init_data();
00031 }
00032
00034 SGMatrix(T* m, index_t nrows, index_t ncols, bool ref_counting=true)
00035 : SGReferencedData(ref_counting), matrix(m),
00036 num_rows(nrows), num_cols(ncols) { }
00037
00039 SGMatrix(index_t nrows, index_t ncols, bool ref_counting=true)
00040 : SGReferencedData(ref_counting), num_rows(nrows), num_cols(ncols)
00041 {
00042 matrix=SG_MALLOC(T, ((int64_t) nrows)*ncols);
00043 }
00044
00046 SGMatrix(const SGMatrix &orig) : SGReferencedData(orig)
00047 {
00048 copy_data(orig);
00049 }
00050
00052 virtual ~SGMatrix()
00053 {
00054 unref();
00055 }
00056
00060 T* get_column_vector(index_t col) const
00061 {
00062 return &matrix[col*num_rows];
00063 }
00064
00069 inline const T& operator()(index_t i_row, index_t i_col) const
00070 {
00071 return matrix[i_col*num_rows + i_row];
00072 }
00073
00077 inline const T& operator[](index_t index) const
00078 {
00079 return matrix[index];
00080 }
00081
00086 inline T& operator()(index_t i_row, index_t i_col)
00087 {
00088 return matrix[i_col*num_rows + i_row];
00089 }
00090
00094 inline T& operator[](index_t index)
00095 {
00096 return matrix[index];
00097 }
00098
00100 inline bool operator==(SGMatrix<T>& other)
00101 {
00102 if (num_rows!=other.num_rows || num_cols!=other.num_cols)
00103 return false;
00104
00105 if (matrix!=other.matrix)
00106 return false;
00107
00108 return true;
00109 }
00110
00117 inline bool equals(SGMatrix<T>& other)
00118 {
00119 if (num_rows!=other.num_rows || num_cols!=other.num_cols)
00120 return false;
00121
00122 for (index_t i=0; i<num_rows*num_cols; ++i)
00123 {
00124 if (matrix[i]!=other.matrix[i])
00125 return false;
00126 }
00127
00128 return true;
00129 }
00130
00132 void set_const(T const_elem)
00133 {
00134 for (index_t i=0; i<num_rows*num_cols; i++)
00135 matrix[i]=const_elem ;
00136 }
00137
00139 void zero()
00140 {
00141 if (matrix && (num_rows*num_cols))
00142 set_const(0);
00143 }
00144
00146 SGMatrix<T> clone()
00147 {
00148 return SGMatrix<T>(clone_matrix(matrix, num_rows, num_cols),
00149 num_rows, num_cols);
00150 }
00151
00153 static T* clone_matrix(const T* matrix, int32_t nrows, int32_t ncols)
00154 {
00155 T* result = SG_MALLOC(T, int64_t(nrows)*ncols);
00156 for (int64_t i=0; i<int64_t(nrows)*ncols; i++)
00157 result[i]=matrix[i];
00158
00159 return result;
00160 }
00161
00163 static void transpose_matrix(
00164 T*& matrix, int32_t& num_feat, int32_t& num_vec);
00165
00167 static void create_diagonal_matrix(T* matrix, T* v,int32_t size)
00168 {
00169 for(int32_t i=0;i<size;i++)
00170 {
00171 for(int32_t j=0;j<size;j++)
00172 {
00173 if(i==j)
00174 matrix[j*size+i]=v[i];
00175 else
00176 matrix[j*size+i]=0;
00177 }
00178 }
00179 }
00180
00186 static SGMatrix<T> create_identity_matrix(index_t size, T scale);
00187
00199 static SGMatrix<float64_t> create_centering_matrix(index_t size);
00200
00201 #ifdef HAVE_LAPACK
00202
00210 static SGVector<float64_t> compute_eigenvectors(
00211 SGMatrix<float64_t> matrix);
00212
00220 static double* compute_eigenvectors(double* matrix, int n, int m);
00221
00232 void compute_few_eigenvectors(double* matrix_, double*& eigenvalues, double*& eigenvectors,
00233 int n, int il, int iu);
00234 #endif
00235
00243 static SGMatrix<float64_t> matrix_multiply(
00244 SGMatrix<float64_t> A, SGMatrix<float64_t> B,
00245 bool transpose_A=false, bool transpose_B=false,
00246 float64_t scale=1.0);
00247 #ifdef HAVE_LAPACK
00248
00249 static void inverse(SGMatrix<float64_t> matrix);
00250
00254 static float64_t* pinv(
00255 float64_t* matrix, int32_t rows, int32_t cols,
00256 float64_t* target=NULL);
00257
00258 #endif
00259
00261 static inline float64_t trace(
00262 float64_t* mat, int32_t cols, int32_t rows)
00263 {
00264 float64_t trace=0;
00265 for (int32_t i=0; i<rows; i++)
00266 trace+=mat[i*cols+i];
00267 return trace;
00268 }
00269
00271 static T* get_row_sum(T* matrix, int32_t m, int32_t n)
00272 {
00273 T* rowsums=SG_CALLOC(T, n);
00274
00275 for (int32_t i=0; i<n; i++)
00276 {
00277 for (int32_t j=0; j<m; j++)
00278 rowsums[i]+=matrix[j+int64_t(i)*m];
00279 }
00280 return rowsums;
00281 }
00282
00284 static T* get_column_sum(T* matrix, int32_t m, int32_t n)
00285 {
00286 T* colsums=SG_CALLOC(T, m);
00287
00288 for (int32_t i=0; i<n; i++)
00289 {
00290 for (int32_t j=0; j<m; j++)
00291 colsums[j]+=matrix[j+int64_t(i)*m];
00292 }
00293 return colsums;
00294 }
00295
00297 void center()
00298 {
00299 center_matrix(matrix, num_rows, num_cols);
00300 }
00301
00303 static void center_matrix(T* matrix, int32_t m, int32_t n);
00304
00306 void remove_column_mean();
00307
00309 void display_matrix(const char* name="matrix") const;
00310
00312 static void display_matrix(
00313 const T* matrix, int32_t rows, int32_t cols,
00314 const char* name="matrix", const char* prefix="");
00315
00317 static void display_matrix(
00318 const SGMatrix<T> matrix, const char* name="matrix",
00319 const char* prefix="");
00320
00332 static SGMatrix<T> get_allocated_matrix(index_t num_rows,
00333 index_t num_cols, SGMatrix<T> pre_allocated=SGMatrix<T>());
00334
00335 protected:
00337 virtual void copy_data(const SGReferencedData &orig)
00338 {
00339 matrix=((SGMatrix*)(&orig))->matrix;
00340 num_rows=((SGMatrix*)(&orig))->num_rows;
00341 num_cols=((SGMatrix*)(&orig))->num_cols;
00342 }
00343
00345 virtual void init_data()
00346 {
00347 matrix=NULL;
00348 num_rows=0;
00349 num_cols=0;
00350 }
00351
00353 virtual void free_data()
00354 {
00355 SG_FREE(matrix);
00356 matrix=NULL;
00357 num_rows=0;
00358 num_cols=0;
00359 }
00360
00361 public:
00363 T* matrix;
00365 index_t num_rows;
00367 index_t num_cols;
00368 };
00369 }
00370 #endif // __SGMATRIX_H__