ViennaCL - The Vienna Computing Library
Version 1.5.0
00001 #ifndef VIENNACL_TOOLS_MATRIX_SIZE_DEDUCER_HPP_ 00002 #define VIENNACL_TOOLS_MATRIX_SIZE_DEDUCER_HPP_ 00003 00004 /* ========================================================================= 00005 Copyright (c) 2010-2013, Institute for Microelectronics, 00006 Institute for Analysis and Scientific Computing, 00007 TU Wien. 00008 Portions of this software are copyright by UChicago Argonne, LLC. 00009 00010 ----------------- 00011 ViennaCL - The Vienna Computing Library 00012 ----------------- 00013 00014 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 00015 00016 (A list of authors and contributors can be found in the PDF manual) 00017 00018 License: MIT (X11), see file LICENSE in the base directory 00019 ============================================================================= */ 00020 00025 #include <string> 00026 #include <fstream> 00027 #include <sstream> 00028 #include <cmath> 00029 #include <vector> 00030 #include <map> 00031 00032 #include "viennacl/forwards.h" 00033 #include "viennacl/tools/adapter.hpp" 00034 00035 namespace viennacl 00036 { 00037 namespace tools 00038 { 00039 00046 template <typename LHS, typename RHS, typename OP> 00047 struct MATRIX_SIZE_DEDUCER 00048 { 00049 //Standard case: size1 from lhs, size2 from rhs (fits most cases) 00050 static vcl_size_t size1(LHS & lhs, RHS & /*rhs*/) { return lhs.size1(); } 00051 static vcl_size_t size2(LHS & /*lhs*/, RHS & rhs) { return rhs.size2(); } 00052 }; 00053 00055 //special case: outer vector product: 00056 template <typename ScalarType> 00057 struct MATRIX_SIZE_DEDUCER<const viennacl::vector_base<ScalarType>, 00058 const viennacl::vector_base<ScalarType>, 00059 viennacl::op_prod> 00060 { 00061 static vcl_size_t size1(viennacl::vector_base<ScalarType> const & lhs, 00062 viennacl::vector_base<ScalarType> const & /*rhs*/) { return lhs.size(); } 00063 00064 static vcl_size_t size2(viennacl::vector_base<ScalarType> const & /*lhs*/, 00065 viennacl::vector_base<ScalarType> const & rhs) { return 
rhs.size(); } 00066 }; 00067 00068 00069 //special case: multiplication with a scalar 00070 template <typename LHS, typename RHS, typename OP, typename ScalarType> 00071 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<const LHS, const RHS, OP>, 00072 const ScalarType, 00073 viennacl::op_mult> 00074 { 00075 static vcl_size_t size1(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs, 00076 ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size1(lhs.lhs(), lhs.rhs()); } 00077 00078 static vcl_size_t size2(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs, 00079 ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size2(lhs.lhs(), lhs.rhs()); } 00080 }; 00081 00082 //special case: multiplication with a scalar 00083 template <typename T, typename F, typename ScalarType> 00084 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<T, F>, 00085 const ScalarType, 00086 viennacl::op_mult> 00087 { 00088 static vcl_size_t size1(viennacl::matrix_base<T, F> const & lhs, 00089 ScalarType const & /*rhs*/) { return lhs.size1(); } 00090 00091 static vcl_size_t size2(viennacl::matrix_base<T, F> const & lhs, 00092 ScalarType const & /*rhs*/) { return lhs.size2(); } 00093 }; 00094 00095 00096 //special case: division with a scalar 00097 template <typename LHS, typename RHS, typename OP, typename ScalarType> 00098 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<const LHS, const RHS, OP>, 00099 const ScalarType, 00100 viennacl::op_div> 00101 { 00102 static vcl_size_t size1(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs, 00103 ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size1(lhs.lhs(), lhs.rhs()); } 00104 00105 static vcl_size_t size2(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs, 00106 ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size2(lhs.lhs(), 
lhs.rhs()); } 00107 }; 00108 00109 //special case: division with a scalar 00110 template <typename T, typename F, typename ScalarType> 00111 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<T, F>, 00112 const ScalarType, 00113 viennacl::op_div> 00114 { 00115 static vcl_size_t size1(viennacl::matrix_base<T, F> const & lhs, 00116 ScalarType const & /*rhs*/) { return lhs.size1(); } 00117 00118 static vcl_size_t size2(viennacl::matrix_base<T, F> const & lhs, 00119 ScalarType const & /*rhs*/) { return lhs.size2(); } 00120 }; 00121 00122 //special case: diagonal from vector 00123 template <typename T> 00124 struct MATRIX_SIZE_DEDUCER<const viennacl::vector_base<T>, 00125 const int, 00126 viennacl::op_vector_diag> 00127 { 00128 static vcl_size_t size1(viennacl::vector_base<T> const & lhs, 00129 const int k) { return lhs.size() + static_cast<vcl_size_t>(std::fabs(double(k))); } 00130 00131 static vcl_size_t size2(viennacl::vector_base<T> const & lhs, 00132 const int k) { return lhs.size() + static_cast<vcl_size_t>(std::fabs(double(k))); } 00133 }; 00134 00135 00136 00137 00138 00139 00140 00141 00142 //special case: transposed matrix-vector product: Return the number of rows of the matrix 00143 template <typename MatrixType> 00144 struct MATRIX_SIZE_DEDUCER<MatrixType, 00145 MatrixType, 00146 viennacl::op_trans> 00147 { 00148 static vcl_size_t size1(const MatrixType & lhs, 00149 const MatrixType & /*rhs*/) { return lhs.size2(); } 00150 static vcl_size_t size2(const MatrixType & lhs, 00151 const MatrixType & /*rhs*/) { return lhs.size1(); } 00152 }; 00153 00154 // A^T * B 00155 template <typename ScalarType, typename T1, typename F2> 00156 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<T1, 00157 T1, op_trans>, 00158 const viennacl::matrix_base<ScalarType, F2>, 00159 viennacl::op_mat_mat_prod> 00160 { 00161 static vcl_size_t size1(viennacl::matrix_expression<T1, 00162 T1, 00163 op_trans> const & lhs, 00164 viennacl::matrix_base<ScalarType, F2> const & 
/*rhs*/) { return lhs.lhs().size2(); } 00165 static vcl_size_t size2(viennacl::matrix_expression<T1, 00166 T1, 00167 op_trans> const & /*lhs*/, 00168 viennacl::matrix_base<ScalarType, F2> const & rhs) { return rhs.size2(); } 00169 }; 00170 00171 00172 // A * B^T 00173 00174 template <typename ScalarType, typename F1, typename T2> 00175 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<ScalarType, F1>, 00176 const viennacl::matrix_expression<T2, 00177 T2, op_trans>, 00178 viennacl::op_mat_mat_prod> 00179 { 00180 static vcl_size_t size1(viennacl::matrix_base<ScalarType, F1> const & lhs, 00181 viennacl::matrix_expression<T2, 00182 T2, 00183 op_trans> const & /*rhs*/) { return lhs.size1(); } 00184 static vcl_size_t size2(viennacl::matrix_base<ScalarType, F1> const & /*lhs*/, 00185 viennacl::matrix_expression<T2, 00186 T2, 00187 op_trans> const & rhs) { return rhs.lhs().size1(); } 00188 }; 00189 00190 00191 00192 00193 // A^T * B^T 00194 00195 template <typename T1, typename T2> 00196 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<T1, 00197 T1, op_trans>, 00198 const viennacl::matrix_expression<T2, 00199 T2, op_trans>, 00200 viennacl::op_mat_mat_prod> 00201 { 00202 typedef viennacl::matrix_expression<T1, T1, op_trans> LHSType; 00203 typedef viennacl::matrix_expression<T2, T2, op_trans> RHSType; 00204 00205 static vcl_size_t size1(LHSType const & lhs, 00206 RHSType const & /*rhs*/) { return lhs.lhs().size2(); } 00207 static vcl_size_t size2(LHSType const & /*lhs*/, 00208 RHSType const & rhs) { return rhs.lhs().size1(); } 00209 }; 00211 } 00212 } 00213 00214 #endif 00215