ViennaCL - The Vienna Computing Library
1.5.0
|
00001 #ifndef VIENNACL_GENERATOR_GENERATE_UTILS_HPP 00002 #define VIENNACL_GENERATOR_GENERATE_UTILS_HPP 00003 00004 /* ========================================================================= 00005 Copyright (c) 2010-2013, Institute for Microelectronics, 00006 Institute for Analysis and Scientific Computing, 00007 TU Wien. 00008 Portions of this software are copyright by UChicago Argonne, LLC. 00009 00010 ----------------- 00011 ViennaCL - The Vienna Computing Library 00012 ----------------- 00013 00014 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 00015 00016 (A list of authors and contributors can be found in the PDF manual) 00017 00018 License: MIT (X11), see file LICENSE in the base directory 00019 ============================================================================= */ 00020 00021 00026 #include <set> 00027 00028 #ifdef __APPLE__ 00029 #include <OpenCL/cl.h> 00030 #else 00031 #include "CL/cl.h" 00032 #endif 00033 00034 #include "viennacl/forwards.h" 00035 #include "viennacl/scheduler/forwards.h" 00036 00037 #include "viennacl/generator/utils.hpp" 00038 #include "viennacl/generator/forwards.h" 00039 00040 namespace viennacl{ 00041 00042 namespace generator{ 00043 00044 namespace detail{ 00045 00047 static std::string generate_value_kernel_argument(std::string const & scalartype, std::string const & name){ 00048 return scalartype + ' ' + name + ","; 00049 } 00050 00052 static std::string generate_pointer_kernel_argument(std::string const & address_space, std::string const & scalartype, std::string const & name){ 00053 return address_space + " " + scalartype + "* " + name + ","; 00054 } 00055 00057 inline const char * generate(viennacl::scheduler::operation_node_type type){ 00058 // unary expression 00059 switch(type){ 00060 case viennacl::scheduler::OPERATION_UNARY_ABS_TYPE : return "abs"; 00061 case viennacl::scheduler::OPERATION_UNARY_TRANS_TYPE : return "trans"; 00062 case viennacl::scheduler::OPERATION_BINARY_ASSIGN_TYPE : return "="; 00063 case viennacl::scheduler::OPERATION_BINARY_INPLACE_ADD_TYPE : return "+="; 00064 case viennacl::scheduler::OPERATION_BINARY_INPLACE_SUB_TYPE : return "-="; 00065 case viennacl::scheduler::OPERATION_BINARY_ADD_TYPE : return "+"; 00066 case viennacl::scheduler::OPERATION_BINARY_SUB_TYPE : return "-"; 00067 case viennacl::scheduler::OPERATION_BINARY_MULT_TYPE : return "*"; 00068 case viennacl::scheduler::OPERATION_BINARY_DIV_TYPE : return "/"; 00069 case viennacl::scheduler::OPERATION_BINARY_INNER_PROD_TYPE : return "iprod"; 00070 case viennacl::scheduler::OPERATION_BINARY_MAT_MAT_PROD_TYPE : return "mmprod"; 00071 case viennacl::scheduler::OPERATION_BINARY_MAT_VEC_PROD_TYPE : return "mvprod"; 00072 case viennacl::scheduler::OPERATION_BINARY_ACCESS_TYPE : return "[]"; 00073 default : throw "not implemented"; 00074 } 00075 } 00076 00078 inline bool is_binary_leaf_operator(viennacl::scheduler::operation_node_type const & op_type) { 00079 return op_type == viennacl::scheduler::OPERATION_BINARY_INNER_PROD_TYPE 00080 ||op_type == viennacl::scheduler::OPERATION_BINARY_MAT_VEC_PROD_TYPE 00081 ||op_type == viennacl::scheduler::OPERATION_BINARY_MAT_MAT_PROD_TYPE; 00082 } 00083 00085 inline bool is_arithmetic_operator(viennacl::scheduler::operation_node_type const & op_type) { 00086 return op_type == viennacl::scheduler::OPERATION_BINARY_ASSIGN_TYPE 00087 ||op_type == viennacl::scheduler::OPERATION_BINARY_ADD_TYPE 00088 ||op_type == viennacl::scheduler::OPERATION_BINARY_DIV_TYPE 00089 ||op_type == viennacl::scheduler::OPERATION_BINARY_ELEMENT_DIV_TYPE 00090 ||op_type == viennacl::scheduler::OPERATION_BINARY_ELEMENT_PROD_TYPE 00091 ||op_type == viennacl::scheduler::OPERATION_BINARY_INPLACE_ADD_TYPE 00092 ||op_type == viennacl::scheduler::OPERATION_BINARY_INPLACE_SUB_TYPE 00093 // ||op_type == viennacl::scheduler::OPERATION_BINARY_INPLACE_DIV_TYPE 00094 // ||op_type == viennacl::scheduler::OPERATION_BINARY_INPLACE_MULT_TYPE 00095 ||op_type == viennacl::scheduler::OPERATION_BINARY_MULT_TYPE 00096 ||op_type == viennacl::scheduler::OPERATION_BINARY_SUB_TYPE; 00097 00098 } 00099 00101 template<class Fun> 00102 static void traverse(viennacl::scheduler::statement const & statement, viennacl::scheduler::statement_node const & root_node, Fun const & fun, bool recurse_binary_leaf /* see forwards.h for default argument */){ 00103 00104 if(root_node.op.type_family==viennacl::scheduler::OPERATION_UNARY_TYPE_FAMILY) 00105 { 00106 //Self: 00107 fun(&statement, &root_node, PARENT_NODE_TYPE); 00108 00109 //Lhs: 00110 fun.call_before_expansion(); 00111 if(root_node.lhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00112 traverse(statement, statement.array()[root_node.lhs.node_index], fun, recurse_binary_leaf); 00113 fun(&statement, &root_node, LHS_NODE_TYPE); 00114 fun.call_after_expansion(); 00115 } 00116 else if(root_node.op.type_family==viennacl::scheduler::OPERATION_BINARY_TYPE_FAMILY) 00117 { 00118 bool deep_recursion = recurse_binary_leaf || !is_binary_leaf_operator(root_node.op.type); 00119 00120 fun.call_before_expansion(); 00121 00122 //Lhs: 00123 if(deep_recursion){ 00124 if(root_node.lhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00125 traverse(statement, statement.array()[root_node.lhs.node_index], fun, recurse_binary_leaf); 00126 fun(&statement, &root_node, LHS_NODE_TYPE); 00127 } 00128 00129 //Self: 00130 fun(&statement, &root_node, PARENT_NODE_TYPE); 00131 00132 //Rhs: 00133 if(deep_recursion){ 00134 if(root_node.rhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00135 traverse(statement, statement.array()[root_node.rhs.node_index], fun, recurse_binary_leaf); 00136 fun(&statement, &root_node, RHS_NODE_TYPE); 00137 } 00138 00139 fun.call_after_expansion(); 00140 00141 } 00142 } 00143 00145 class traversal_functor{ 00146 public: 00147 void call_before_expansion() const { } 00148 void call_after_expansion() const { } 00149 }; 00150 00152 class prototype_generation_traversal : public traversal_functor{ 00153 private: 00154 std::set<std::string> & already_generated_; 00155 std::string & str_; 00156 unsigned int vector_size_; 00157 mapping_type const & mapping_; 00158 public: 00159 prototype_generation_traversal(std::set<std::string> & already_generated, std::string & str, unsigned int vector_size, mapping_type const & mapping) : already_generated_(already_generated), str_(str), vector_size_(vector_size), mapping_(mapping){ } 00160 00161 void operator()(viennacl::scheduler::statement const *, viennacl::scheduler::statement_node const * root_node, detail::node_type node_type) const { 00162 if( (node_type==detail::LHS_NODE_TYPE && root_node->lhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00163 ||(node_type==detail::RHS_NODE_TYPE && root_node->rhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) ) 00164 append_kernel_arguments(already_generated_, str_, vector_size_, *at(mapping_, std::make_pair(root_node,node_type))); 00165 } 00166 }; 00167 00169 class fetch_traversal : public traversal_functor{ 00170 private: 00171 std::set<std::string> & fetched_; 00172 std::pair<std::string, std::string> index_string_; 00173 unsigned int vectorization_; 00174 utils::kernel_generation_stream & stream_; 00175 mapping_type const & mapping_; 00176 public: 00177 fetch_traversal(std::set<std::string> & fetched, std::pair<std::string, std::string> const & index, unsigned int vectorization, utils::kernel_generation_stream & stream, mapping_type const & mapping) : fetched_(fetched), index_string_(index), vectorization_(vectorization), stream_(stream), mapping_(mapping){ } 00178 00179 void operator()(viennacl::scheduler::statement const *, viennacl::scheduler::statement_node const * root_node, detail::node_type node_type) const { 00180 if( (node_type==detail::LHS_NODE_TYPE && root_node->lhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00181 ||(node_type==detail::RHS_NODE_TYPE && root_node->rhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) ) 00182 fetch(index_string_, vectorization_, fetched_, stream_, *at(mapping_, std::make_pair(root_node, node_type))); 00183 } 00184 }; 00185 00190 static void fetch_all_lhs(std::set<std::string> & fetched 00191 , viennacl::scheduler::statement const & statement 00192 , viennacl::scheduler::statement_node const & root_node 00193 , std::pair<std::string, std::string> const & index 00194 , vcl_size_t const & vectorization 00195 , utils::kernel_generation_stream & stream 00196 , detail::mapping_type const & mapping){ 00197 if(root_node.lhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00198 detail::traverse(statement, statement.array()[root_node.lhs.node_index], detail::fetch_traversal(fetched, index, static_cast<unsigned int>(vectorization), stream, mapping)); 00199 else 00200 detail::fetch(index, static_cast<unsigned int>(vectorization),fetched, stream, *at(mapping, std::make_pair(&root_node,detail::LHS_NODE_TYPE))); 00201 00202 } 00203 00208 static void fetch_all_rhs(std::set<std::string> & fetched 00209 , viennacl::scheduler::statement const & statement 00210 , viennacl::scheduler::statement_node const & root_node 00211 , std::pair<std::string, std::string> const & index 00212 , vcl_size_t const & vectorization 00213 , utils::kernel_generation_stream & stream 00214 , detail::mapping_type const & mapping){ 00215 if(root_node.rhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00216 detail::traverse(statement, statement.array()[root_node.rhs.node_index], detail::fetch_traversal(fetched, index, static_cast<unsigned int>(vectorization), stream, mapping)); 00217 else 00218 detail::fetch(index, static_cast<unsigned int>(vectorization),fetched, stream, *at(mapping, std::make_pair(&root_node,detail::RHS_NODE_TYPE))); 00219 00220 } 00221 00222 00224 class expression_generation_traversal : public traversal_functor{ 00225 private: 00226 std::pair<std::string, std::string> index_string_; 00227 int vector_element_; 00228 std::string & str_; 00229 mapping_type const & mapping_; 00230 00231 public: 00232 expression_generation_traversal(std::pair<std::string, std::string> const & index, int vector_element, std::string & str, mapping_type const & mapping) : index_string_(index), vector_element_(vector_element), str_(str), mapping_(mapping){ } 00233 00234 void call_before_expansion() const { str_+="("; } 00235 void call_after_expansion() const { str_+=")"; } 00236 00237 void operator()(viennacl::scheduler::statement const *, viennacl::scheduler::statement_node const * root_node, detail::node_type node_type) const { 00238 if(node_type==PARENT_NODE_TYPE) 00239 { 00240 if(is_binary_leaf_operator(root_node->op.type)) 00241 str_ += generate(index_string_, vector_element_, *at(mapping_, std::make_pair(root_node, node_type))); 00242 else if(is_arithmetic_operator(root_node->op.type)) 00243 str_ += generate(root_node->op.type); 00244 } 00245 else{ 00246 if(node_type==LHS_NODE_TYPE){ 00247 if(root_node->lhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00248 str_ += detail::generate(index_string_,vector_element_, *at(mapping_, std::make_pair(root_node,node_type))); 00249 } 00250 else if(node_type==RHS_NODE_TYPE){ 00251 if(root_node->rhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00252 str_ += detail::generate(index_string_,vector_element_, *at(mapping_, std::make_pair(root_node,node_type))); 00253 } 00254 } 00255 } 00256 }; 00257 00258 static void generate_all_lhs(viennacl::scheduler::statement const & statement 00259 , viennacl::scheduler::statement_node const & root_node 00260 , std::pair<std::string, std::string> const & index 00261 , int vector_element 00262 , std::string & str 00263 , detail::mapping_type const & mapping){ 00264 if(root_node.lhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00265 detail::traverse(statement, statement.array()[root_node.lhs.node_index], detail::expression_generation_traversal(index, vector_element, str, mapping)); 00266 else 00267 str += detail::generate(index, vector_element,*at(mapping, std::make_pair(&root_node,detail::LHS_NODE_TYPE))); 00268 } 00269 00270 00271 static void generate_all_rhs(viennacl::scheduler::statement const & statement 00272 , viennacl::scheduler::statement_node const & root_node 00273 , std::pair<std::string, std::string> const & index 00274 , int vector_element 00275 , std::string & str 00276 , detail::mapping_type const & mapping){ 00277 if(root_node.rhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) 00278 detail::traverse(statement, statement.array()[root_node.rhs.node_index], detail::expression_generation_traversal(index, vector_element, str, mapping)); 00279 else 00280 str += detail::generate(index, vector_element,*at(mapping, std::make_pair(&root_node,detail::RHS_NODE_TYPE))); 00281 } 00282 00283 } 00284 } 00285 } 00286 #endif