ViennaCL - The Vienna Computing Library  1.5.0
viennacl/linalg/prod.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_LINALG_PROD_HPP_
00002 #define VIENNACL_LINALG_PROD_HPP_
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2013, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00027 #include "viennacl/forwards.h"
00028 #include "viennacl/tools/tools.hpp"
00029 #include "viennacl/meta/enable_if.hpp"
00030 #include "viennacl/meta/tag_of.hpp"
00031 #include <vector>
00032 #include <map>
00033 
00034 namespace viennacl
00035 {
00036   //
00037   // generic prod function
00038   //   uses tag dispatch to identify which algorithm
00039   //   should be called
00040   //
00041   namespace linalg
00042   {
00043     #ifdef VIENNACL_WITH_MTL4
00044     // ----------------------------------------------------
00045     // mtl4
00046     //
00047     template< typename MatrixT, typename VectorT >
00048     typename viennacl::enable_if< viennacl::is_mtl4< typename viennacl::traits::tag_of< MatrixT >::type >::value,
00049                                   VectorT>::type
00050     prod(MatrixT const& matrix, VectorT const& vector)
00051     {
00052       return VectorT(matrix * vector);
00053     }
00054     #endif
00055 
00056     #ifdef VIENNACL_WITH_EIGEN
00057     // ----------------------------------------------------
00058     // Eigen
00059     //
00060     template< typename MatrixT, typename VectorT >
00061     typename viennacl::enable_if< viennacl::is_eigen< typename viennacl::traits::tag_of< MatrixT >::type >::value,
00062                                   VectorT>::type
00063     prod(MatrixT const& matrix, VectorT const& vector)
00064     {
00065       return matrix * vector;
00066     }
00067     #endif
00068 
00069     #ifdef VIENNACL_WITH_UBLAS
00070     // ----------------------------------------------------
00071     // UBLAS
00072     //
00073     template< typename MatrixT, typename VectorT >
00074     typename viennacl::enable_if< viennacl::is_ublas< typename viennacl::traits::tag_of< MatrixT >::type >::value,
00075                                   VectorT>::type
00076     prod(MatrixT const& matrix, VectorT const& vector)
00077     {
00078       // std::cout << "ublas .. " << std::endl;
00079       return boost::numeric::ublas::prod(matrix, vector);
00080     }
00081     #endif
00082 
00083 
00084     // ----------------------------------------------------
00085     // STL type
00086     //
00087 
00088     // dense matrix-vector product:
00089     template< typename T, typename A1, typename A2, typename VectorT >
00090     VectorT
00091     prod(std::vector< std::vector<T, A1>, A2 > const & matrix, VectorT const& vector)
00092     {
00093       VectorT result(matrix.size());
00094       for (typename std::vector<T, A1>::size_type i=0; i<matrix.size(); ++i)
00095       {
00096         result[i] = 0; //we will not assume that VectorT is initialized to zero
00097         for (typename std::vector<T, A1>::size_type j=0; j<matrix[i].size(); ++j)
00098           result[i] += matrix[i][j] * vector[j];
00099       }
00100       return result;
00101     }
00102 
00103     // sparse matrix-vector product:
00104     template< typename KEY, typename DATA, typename COMPARE, typename AMAP, typename AVEC, typename VectorT >
00105     VectorT
00106     prod(std::vector< std::map<KEY, DATA, COMPARE, AMAP>, AVEC > const& matrix, VectorT const& vector)
00107     {
00108       typedef std::vector< std::map<KEY, DATA, COMPARE, AMAP>, AVEC > MatrixType;
00109 
00110       VectorT result(matrix.size());
00111       for (typename MatrixType::size_type i=0; i<matrix.size(); ++i)
00112       {
00113         result[i] = 0; //we will not assume that VectorT is initialized to zero
00114         for (typename std::map<KEY, DATA, COMPARE, AMAP>::const_iterator row_entries = matrix[i].begin();
00115              row_entries != matrix[i].end();
00116              ++row_entries)
00117           result[i] += row_entries->second * vector[row_entries->first];
00118       }
00119       return result;
00120     }
00121 
00122 
00123     /*template< typename MatrixT, typename VectorT >
00124     VectorT
00125     prod(MatrixT const& matrix, VectorT const& vector,
00126          typename viennacl::enable_if< viennacl::is_stl< typename viennacl::traits::tag_of< MatrixT >::type >::value
00127                                      >::type* dummy = 0)
00128     {
00129       // std::cout << "std .. " << std::endl;
00130       return prod_impl(matrix, vector);
00131     }*/
00132 
00133     // ----------------------------------------------------
00134     // VIENNACL
00135     //
00136 
00137     // standard product:
00138     template< typename NumericT, typename F1, typename F2>
00139     viennacl::matrix_expression< const viennacl::matrix_base<NumericT, F1>,
00140                                  const viennacl::matrix_base<NumericT, F2>,
00141                                  viennacl::op_mat_mat_prod >
00142     prod(viennacl::matrix_base<NumericT, F1> const & A,
00143          viennacl::matrix_base<NumericT, F2> const & B)
00144     {
00145       // std::cout << "viennacl .. " << std::endl;
00146       return viennacl::matrix_expression< const viennacl::matrix_base<NumericT, F1>,
00147                                           const viennacl::matrix_base<NumericT, F2>,
00148                                           viennacl::op_mat_mat_prod >(A, B);
00149     }
00150 
00151     // right factor is transposed:
00152     template< typename NumericT, typename F1, typename F2>
00153     viennacl::matrix_expression< const viennacl::matrix_base<NumericT, F1>,
00154                                  const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>,
00155                                                                    const viennacl::matrix_base<NumericT, F2>,
00156                                                                    op_trans>,
00157                                  viennacl::op_mat_mat_prod >
00158     prod(viennacl::matrix_base<NumericT, F1> const & A,
00159          viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>,
00160                                      const viennacl::matrix_base<NumericT, F2>,
00161                                      op_trans> const & B)
00162     {
00163       // std::cout << "viennacl .. " << std::endl;
00164       return viennacl::matrix_expression< const viennacl::matrix_base<NumericT, F1>,
00165                                           const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>,
00166                                                                             const viennacl::matrix_base<NumericT, F2>,
00167                                                                             op_trans>,
00168                                           viennacl::op_mat_mat_prod >(A, B);
00169     }
00170 
00171     // left factor transposed:
00172     template< typename NumericT, typename F1, typename F2>
00173     viennacl::matrix_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>,
00174                                                                    const viennacl::matrix_base<NumericT, F1>,
00175                                                                    op_trans>,
00176                                  const viennacl::matrix_base<NumericT, F2>,
00177                                  viennacl::op_mat_mat_prod >
00178     prod(viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>,
00179                                      const viennacl::matrix_base<NumericT, F1>,
00180                                      op_trans> const & A,
00181          viennacl::matrix_base<NumericT, F2> const & B)
00182     {
00183       // std::cout << "viennacl .. " << std::endl;
00184       return viennacl::matrix_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>,
00185                                                                             const viennacl::matrix_base<NumericT, F1>,
00186                                                                             op_trans>,
00187                                           const viennacl::matrix_base<NumericT, F2>,
00188                                           viennacl::op_mat_mat_prod >(A, B);
00189     }
00190 
00191 
00192     // both factors transposed:
00193     template< typename NumericT, typename F1, typename F2>
00194     viennacl::matrix_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>,
00195                                                                    const viennacl::matrix_base<NumericT, F1>,
00196                                                                    op_trans>,
00197                                  const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>,
00198                                                                    const viennacl::matrix_base<NumericT, F2>,
00199                                                                    op_trans>,
00200                                  viennacl::op_mat_mat_prod >
00201     prod(viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>,
00202                                      const viennacl::matrix_base<NumericT, F1>,
00203                                      op_trans> const & A,
00204          viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>,
00205                                      const viennacl::matrix_base<NumericT, F2>,
00206                                      op_trans> const & B)
00207     {
00208       // std::cout << "viennacl .. " << std::endl;
00209       return viennacl::matrix_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F1>,
00210                                                                             const viennacl::matrix_base<NumericT, F1>,
00211                                                                             op_trans>,
00212                                           const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F2>,
00213                                                                             const viennacl::matrix_base<NumericT, F2>,
00214                                                                             op_trans>,
00215                                           viennacl::op_mat_mat_prod >(A, B);
00216     }
00217 
00218 
00219 
00220     // matrix-vector product
00221     template< typename NumericT, typename F>
00222     viennacl::vector_expression< const viennacl::matrix_base<NumericT, F>,
00223                                  const viennacl::vector_base<NumericT>,
00224                                  viennacl::op_prod >
00225     prod(viennacl::matrix_base<NumericT, F> const & matrix,
00226          viennacl::vector_base<NumericT> const & vector)
00227     {
00228       // std::cout << "viennacl .. " << std::endl;
00229       return viennacl::vector_expression< const viennacl::matrix_base<NumericT, F>,
00230                                           const viennacl::vector_base<NumericT>,
00231                                           viennacl::op_prod >(matrix, vector);
00232     }
00233 
00234     // transposed matrix-vector product
00235     template< typename NumericT, typename F>
00236     viennacl::vector_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F>,
00237                                                                    const viennacl::matrix_base<NumericT, F>,
00238                                                                    op_trans>,
00239                                  const viennacl::vector_base<NumericT>,
00240                                  viennacl::op_prod >
00241     prod(viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F>,
00242                                      const viennacl::matrix_base<NumericT, F>,
00243                                      op_trans> const & matrix,
00244          viennacl::vector_base<NumericT> const & vector)
00245     {
00246       // std::cout << "viennacl .. " << std::endl;
00247       return viennacl::vector_expression< const viennacl::matrix_expression<const viennacl::matrix_base<NumericT, F>,
00248                                                                             const viennacl::matrix_base<NumericT, F>,
00249                                                                             op_trans>,
00250                                           const viennacl::vector_base<NumericT>,
00251                                           viennacl::op_prod >(matrix, vector);
00252     }
00253 
00254 
00255     template<typename SparseMatrixType, class SCALARTYPE>
00256     typename viennacl::enable_if< viennacl::is_any_sparse_matrix<SparseMatrixType>::value,
00257                                   vector_expression<const SparseMatrixType,
00258                                                     const vector_base<SCALARTYPE>,
00259                                                     op_prod >
00260                                  >::type
00261     prod(const SparseMatrixType & mat,
00262          const vector_base<SCALARTYPE> & vec)
00263     {
00264       return vector_expression<const SparseMatrixType,
00265                                const vector_base<SCALARTYPE>,
00266                                op_prod >(mat, vec);
00267     }
00268 
00269     template< typename SparseMatrixType, typename SCALARTYPE, typename F1>
00270     typename viennacl::enable_if< viennacl::is_any_sparse_matrix<SparseMatrixType>::value,
00271                                   viennacl::matrix_expression<const SparseMatrixType,
00272                                                               const matrix_base < SCALARTYPE, F1 >,
00273                                                               op_prod >
00274                                  >::type
00275     prod(const SparseMatrixType & sp_mat,
00276          const viennacl::matrix_base<SCALARTYPE, F1> & d_mat)
00277     {
00278       return viennacl::matrix_expression<const SparseMatrixType,
00279                                          const viennacl::matrix_base < SCALARTYPE, F1 >,
00280                                          op_prod >(sp_mat, d_mat);
00281     }
00282 
00283     // right factor is transposed
00284     template< typename SparseMatrixType, typename SCALARTYPE, typename F1 >
00285     typename viennacl::enable_if< viennacl::is_any_sparse_matrix<SparseMatrixType>::value,
00286                                   viennacl::matrix_expression< const SparseMatrixType,
00287                                                                const viennacl::matrix_expression<const viennacl::matrix_base<SCALARTYPE, F1>,
00288                                                                                                  const viennacl::matrix_base<SCALARTYPE, F1>,
00289                                                                                                  op_trans>,
00290                                                                viennacl::op_prod >
00291                                   >::type
00292     prod(const SparseMatrixType & A,
00293          viennacl::matrix_expression<const viennacl::matrix_base < SCALARTYPE, F1 >,
00294                                      const viennacl::matrix_base < SCALARTYPE, F1 >,
00295                                      op_trans> const & B)
00296     {
00297       return viennacl::matrix_expression< const SparseMatrixType,
00298                                           const viennacl::matrix_expression<const viennacl::matrix_base < SCALARTYPE, F1 >,
00299                                                                             const viennacl::matrix_base < SCALARTYPE, F1 >,
00300                                                                             op_trans>,
00301                                           viennacl::op_prod >(A, B);
00302     }
00303 
00304     template<typename StructuredMatrixType, class SCALARTYPE>
00305     typename viennacl::enable_if< viennacl::is_any_dense_structured_matrix<StructuredMatrixType>::value,
00306                                   vector_expression<const StructuredMatrixType,
00307                                                     const vector_base<SCALARTYPE>,
00308                                                     op_prod >
00309                                  >::type
00310     prod(const StructuredMatrixType & mat,
00311          const vector_base<SCALARTYPE> & vec)
00312     {
00313       return vector_expression<const StructuredMatrixType,
00314                                const vector_base<SCALARTYPE>,
00315                                op_prod >(mat, vec);
00316     }
00317 
00318   } // end namespace linalg
00319 } // end namespace viennacl
00320 #endif
00321 
00322 
00323 
00324 
00325