ViennaCL - The Vienna Computing Library  1.5.0
viennacl/generator/helpers.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_GENERATOR_GENERATE_UTILS_HPP
00002 #define VIENNACL_GENERATOR_GENERATE_UTILS_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2013, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00021 
00026 #include <set>
00027 
00028 #ifdef __APPLE__
00029 #include <OpenCL/cl.h>
00030 #else
00031 #include "CL/cl.h"
00032 #endif
00033 
00034 #include "viennacl/forwards.h"
00035 #include "viennacl/scheduler/forwards.h"
00036 
00037 #include "viennacl/generator/utils.hpp"
00038 #include "viennacl/generator/forwards.h"
00039 
00040 namespace viennacl{
00041 
00042   namespace generator{
00043 
00044     namespace detail{
00045 
/** @brief Builds the text of a by-value kernel argument: "<scalartype> <name>,".
*
* @param scalartype  Type name placed before the argument name
* @param name        Argument name
* @return The argument text, terminated by a comma
*/
static std::string generate_value_kernel_argument(std::string const & scalartype, std::string const & name){
  std::string argument(scalartype);
  argument += ' ';
  argument += name;
  argument += ",";
  return argument;
}
00050 
/** @brief Builds the text of a pointer kernel argument: "<address_space> <scalartype>* <name>,".
*
* @param address_space  Qualifier placed in front of the type
* @param scalartype     Pointee type name
* @param name           Argument name
* @return The argument text, terminated by a comma
*/
static std::string generate_pointer_kernel_argument(std::string const & address_space, std::string const & scalartype, std::string const & name){
  std::string argument(address_space);
  argument.append(" ").append(scalartype).append("* ").append(name).append(",");
  return argument;
}
00055 
00057       inline const char * generate(viennacl::scheduler::operation_node_type type){
00058         // unary expression
00059         switch(type){
00060           case viennacl::scheduler::OPERATION_UNARY_ABS_TYPE : return "abs";
00061           case viennacl::scheduler::OPERATION_UNARY_TRANS_TYPE : return "trans";
00062           case viennacl::scheduler::OPERATION_BINARY_ASSIGN_TYPE : return "=";
00063           case viennacl::scheduler::OPERATION_BINARY_INPLACE_ADD_TYPE : return "+=";
00064           case viennacl::scheduler::OPERATION_BINARY_INPLACE_SUB_TYPE : return "-=";
00065           case viennacl::scheduler::OPERATION_BINARY_ADD_TYPE : return "+";
00066           case viennacl::scheduler::OPERATION_BINARY_SUB_TYPE : return "-";
00067           case viennacl::scheduler::OPERATION_BINARY_MULT_TYPE : return "*";
00068           case viennacl::scheduler::OPERATION_BINARY_DIV_TYPE : return "/";
00069           case viennacl::scheduler::OPERATION_BINARY_INNER_PROD_TYPE : return "iprod";
00070           case viennacl::scheduler::OPERATION_BINARY_MAT_MAT_PROD_TYPE : return "mmprod";
00071           case viennacl::scheduler::OPERATION_BINARY_MAT_VEC_PROD_TYPE : return "mvprod";
00072           case viennacl::scheduler::OPERATION_BINARY_ACCESS_TYPE : return "[]";
00073           default : throw "not implemented";
00074         }
00075       }
00076 
00078       inline bool is_binary_leaf_operator(viennacl::scheduler::operation_node_type const & op_type) {
00079         return op_type == viennacl::scheduler::OPERATION_BINARY_INNER_PROD_TYPE
00080              ||op_type == viennacl::scheduler::OPERATION_BINARY_MAT_VEC_PROD_TYPE
00081              ||op_type == viennacl::scheduler::OPERATION_BINARY_MAT_MAT_PROD_TYPE;
00082       }
00083 
00085       inline bool is_arithmetic_operator(viennacl::scheduler::operation_node_type const & op_type) {
00086         return op_type == viennacl::scheduler::OPERATION_BINARY_ASSIGN_TYPE
00087              ||op_type == viennacl::scheduler::OPERATION_BINARY_ADD_TYPE
00088              ||op_type == viennacl::scheduler::OPERATION_BINARY_DIV_TYPE
00089              ||op_type == viennacl::scheduler::OPERATION_BINARY_ELEMENT_DIV_TYPE
00090              ||op_type == viennacl::scheduler::OPERATION_BINARY_ELEMENT_PROD_TYPE
00091              ||op_type == viennacl::scheduler::OPERATION_BINARY_INPLACE_ADD_TYPE
00092              ||op_type == viennacl::scheduler::OPERATION_BINARY_INPLACE_SUB_TYPE
00093 //                 ||op_type == viennacl::scheduler::OPERATION_BINARY_INPLACE_DIV_TYPE
00094 //                ||op_type == viennacl::scheduler::OPERATION_BINARY_INPLACE_MULT_TYPE
00095             ||op_type == viennacl::scheduler::OPERATION_BINARY_MULT_TYPE
00096             ||op_type == viennacl::scheduler::OPERATION_BINARY_SUB_TYPE;
00097 
00098       }
00099 
00101       template<class Fun>
00102       static void traverse(viennacl::scheduler::statement const & statement, viennacl::scheduler::statement_node const & root_node, Fun const & fun, bool recurse_binary_leaf /* see forwards.h for default argument */){
00103 
00104         if(root_node.op.type_family==viennacl::scheduler::OPERATION_UNARY_TYPE_FAMILY)
00105         {
00106           //Self:
00107           fun(&statement, &root_node, PARENT_NODE_TYPE);
00108 
00109           //Lhs:
00110           fun.call_before_expansion();
00111           if(root_node.lhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00112               traverse(statement, statement.array()[root_node.lhs.node_index], fun, recurse_binary_leaf);
00113           fun(&statement, &root_node, LHS_NODE_TYPE);
00114           fun.call_after_expansion();
00115         }
00116         else if(root_node.op.type_family==viennacl::scheduler::OPERATION_BINARY_TYPE_FAMILY)
00117         {
00118           bool deep_recursion = recurse_binary_leaf || !is_binary_leaf_operator(root_node.op.type);
00119 
00120           fun.call_before_expansion();
00121 
00122           //Lhs:
00123           if(deep_recursion){
00124             if(root_node.lhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00125               traverse(statement, statement.array()[root_node.lhs.node_index], fun, recurse_binary_leaf);
00126             fun(&statement, &root_node, LHS_NODE_TYPE);
00127           }
00128 
00129           //Self:
00130           fun(&statement, &root_node, PARENT_NODE_TYPE);
00131 
00132           //Rhs:
00133           if(deep_recursion){
00134             if(root_node.rhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00135               traverse(statement, statement.array()[root_node.rhs.node_index], fun, recurse_binary_leaf);
00136             fun(&statement, &root_node, RHS_NODE_TYPE);
00137           }
00138 
00139           fun.call_after_expansion();
00140 
00141         }
00142       }
00143 
/** @brief Base class for the functors passed to traverse().
*
* Supplies empty expansion hooks. The hooks are not virtual; a derived functor
* that needs them declares its own functions with the same names.
*/
class traversal_functor{
  public:
    /** @brief Hook invoked by traverse() before the operands of a node are visited. Does nothing. */
    void call_before_expansion() const
    {
    }

    /** @brief Hook invoked by traverse() after the operands of a node were visited. Does nothing. */
    void call_after_expansion() const
    {
    }
};
00150 
00152       class prototype_generation_traversal : public traversal_functor{
00153         private:
00154           std::set<std::string> & already_generated_;
00155           std::string & str_;
00156           unsigned int vector_size_;
00157           mapping_type const & mapping_;
00158         public:
00159           prototype_generation_traversal(std::set<std::string> & already_generated, std::string & str, unsigned int vector_size, mapping_type const & mapping) : already_generated_(already_generated), str_(str), vector_size_(vector_size), mapping_(mapping){ }
00160 
00161           void operator()(viennacl::scheduler::statement const *, viennacl::scheduler::statement_node const * root_node, detail::node_type node_type) const {
00162               if( (node_type==detail::LHS_NODE_TYPE && root_node->lhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00163                 ||(node_type==detail::RHS_NODE_TYPE && root_node->rhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) )
00164                   append_kernel_arguments(already_generated_, str_, vector_size_, *at(mapping_, std::make_pair(root_node,node_type)));
00165           }
00166       };
00167 
00169       class fetch_traversal : public traversal_functor{
00170         private:
00171           std::set<std::string> & fetched_;
00172           std::pair<std::string, std::string> index_string_;
00173           unsigned int vectorization_;
00174           utils::kernel_generation_stream & stream_;
00175           mapping_type const & mapping_;
00176         public:
00177           fetch_traversal(std::set<std::string> & fetched, std::pair<std::string, std::string> const & index, unsigned int vectorization, utils::kernel_generation_stream & stream, mapping_type const & mapping) : fetched_(fetched), index_string_(index), vectorization_(vectorization), stream_(stream), mapping_(mapping){ }
00178 
00179           void operator()(viennacl::scheduler::statement const *, viennacl::scheduler::statement_node const * root_node, detail::node_type node_type) const {
00180             if( (node_type==detail::LHS_NODE_TYPE && root_node->lhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00181               ||(node_type==detail::RHS_NODE_TYPE && root_node->rhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY) )
00182               fetch(index_string_, vectorization_, fetched_, stream_, *at(mapping_, std::make_pair(root_node, node_type)));
00183           }
00184       };
00185 
00190       static void fetch_all_lhs(std::set<std::string> & fetched
00191                                 , viennacl::scheduler::statement const & statement
00192                                 , viennacl::scheduler::statement_node const & root_node
00193                                 , std::pair<std::string, std::string> const & index
00194                                 , vcl_size_t const & vectorization
00195                                 , utils::kernel_generation_stream & stream
00196                                 , detail::mapping_type const & mapping){
00197         if(root_node.lhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00198           detail::traverse(statement, statement.array()[root_node.lhs.node_index], detail::fetch_traversal(fetched, index, static_cast<unsigned int>(vectorization), stream, mapping));
00199         else
00200           detail::fetch(index, static_cast<unsigned int>(vectorization),fetched, stream, *at(mapping, std::make_pair(&root_node,detail::LHS_NODE_TYPE)));
00201 
00202       }
00203 
00208       static void fetch_all_rhs(std::set<std::string> & fetched
00209                                 , viennacl::scheduler::statement const & statement
00210                                 , viennacl::scheduler::statement_node const & root_node
00211                                 , std::pair<std::string, std::string> const & index
00212                                 , vcl_size_t const & vectorization
00213                                 , utils::kernel_generation_stream & stream
00214                                 , detail::mapping_type const & mapping){
00215         if(root_node.rhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00216           detail::traverse(statement, statement.array()[root_node.rhs.node_index], detail::fetch_traversal(fetched, index, static_cast<unsigned int>(vectorization), stream, mapping));
00217         else
00218           detail::fetch(index, static_cast<unsigned int>(vectorization),fetched, stream, *at(mapping, std::make_pair(&root_node,detail::RHS_NODE_TYPE)));
00219 
00220       }
00221 
00222 
00224       class expression_generation_traversal : public traversal_functor{
00225         private:
00226           std::pair<std::string, std::string> index_string_;
00227           int vector_element_;
00228           std::string & str_;
00229           mapping_type const & mapping_;
00230 
00231         public:
00232           expression_generation_traversal(std::pair<std::string, std::string> const & index, int vector_element, std::string & str, mapping_type const & mapping) : index_string_(index), vector_element_(vector_element), str_(str), mapping_(mapping){ }
00233 
00234           void call_before_expansion() const { str_+="("; }
00235           void call_after_expansion() const { str_+=")"; }
00236 
00237           void operator()(viennacl::scheduler::statement const *, viennacl::scheduler::statement_node const * root_node, detail::node_type node_type) const {
00238             if(node_type==PARENT_NODE_TYPE)
00239             {
00240               if(is_binary_leaf_operator(root_node->op.type))
00241                 str_ += generate(index_string_, vector_element_, *at(mapping_, std::make_pair(root_node, node_type)));
00242               else if(is_arithmetic_operator(root_node->op.type))
00243                 str_ += generate(root_node->op.type);
00244             }
00245             else{
00246               if(node_type==LHS_NODE_TYPE){
00247                 if(root_node->lhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00248                   str_ += detail::generate(index_string_,vector_element_, *at(mapping_, std::make_pair(root_node,node_type)));
00249               }
00250               else if(node_type==RHS_NODE_TYPE){
00251                 if(root_node->rhs.type_family!=viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00252                   str_ += detail::generate(index_string_,vector_element_, *at(mapping_, std::make_pair(root_node,node_type)));
00253               }
00254             }
00255           }
00256       };
00257 
00258       static void generate_all_lhs(viennacl::scheduler::statement const & statement
00259                                 , viennacl::scheduler::statement_node const & root_node
00260                                 , std::pair<std::string, std::string> const & index
00261                                 , int vector_element
00262                                 , std::string & str
00263                                 , detail::mapping_type const & mapping){
00264         if(root_node.lhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00265           detail::traverse(statement, statement.array()[root_node.lhs.node_index], detail::expression_generation_traversal(index, vector_element, str, mapping));
00266         else
00267           str += detail::generate(index, vector_element,*at(mapping, std::make_pair(&root_node,detail::LHS_NODE_TYPE)));
00268       }
00269 
00270 
00271       static void generate_all_rhs(viennacl::scheduler::statement const & statement
00272                                 , viennacl::scheduler::statement_node const & root_node
00273                                 , std::pair<std::string, std::string> const & index
00274                                 , int vector_element
00275                                 , std::string & str
00276                                 , detail::mapping_type const & mapping){
00277         if(root_node.rhs.type_family==viennacl::scheduler::COMPOSITE_OPERATION_FAMILY)
00278           detail::traverse(statement, statement.array()[root_node.rhs.node_index], detail::expression_generation_traversal(index, vector_element, str, mapping));
00279         else
00280           str += detail::generate(index, vector_element,*at(mapping, std::make_pair(&root_node,detail::RHS_NODE_TYPE)));
00281       }
00282 
00283     }
00284   }
00285 }
00286 #endif