ViennaCL - The Vienna Computing Library  1.5.0
viennacl/scheduler/execute_vector_assign.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_SCHEDULER_EXECUTE_VECTOR_ASSIGN_HPP
00002 #define VIENNACL_SCHEDULER_EXECUTE_VECTOR_ASSIGN_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2013, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00021 
00026 #include "viennacl/forwards.h"
00027 #include "viennacl/scheduler/forwards.h"
00028 #include "viennacl/scheduler/execute_vector_dispatcher.hpp"
00029 
00030 namespace viennacl
00031 {
00032   namespace scheduler
00033   {
00034 
00035     // forward declaration
00036     inline void execute_vector(statement const & s, statement_node const & root_node);
00037 
00038 
00040     inline void execute_vector_composite(statement const & s, statement_node const & root_node)
00041     {
00042       statement::container_type const & expr = s.array();
00043 
00044       statement_node const & leaf = expr[root_node.rhs.node_index];
00045 
00046       if (leaf.op.type  == OPERATION_BINARY_ADD_TYPE || leaf.op.type  == OPERATION_BINARY_SUB_TYPE) // x = (y) +- (z)  where y and z are either vectors or expressions
00047       {
00048         bool flip_sign_z = (leaf.op.type  == OPERATION_BINARY_SUB_TYPE);
00049 
00050         if (   leaf.lhs.type_family == VECTOR_TYPE_FAMILY
00051             && leaf.lhs.type_family == VECTOR_TYPE_FAMILY)
00052         {
00053           lhs_rhs_element u = root_node.lhs;
00054           lhs_rhs_element v = leaf.lhs;
00055           lhs_rhs_element w = leaf.rhs;
00056           switch (root_node.op.type)
00057           {
00058             case OPERATION_BINARY_ASSIGN_TYPE:
00059               detail::avbv(u,
00060                            v, 1.0, 1, false, false,
00061                            w, 1.0, 1, false, flip_sign_z);
00062               break;
00063             case OPERATION_BINARY_INPLACE_ADD_TYPE:
00064               detail::avbv_v(u,
00065                              v, 1.0, 1, false, false,
00066                              w, 1.0, 1, false, flip_sign_z);
00067               break;
00068             case OPERATION_BINARY_INPLACE_SUB_TYPE:
00069               detail::avbv_v(u,
00070                              v, 1.0, 1, false, true,
00071                              w, 1.0, 1, false, !flip_sign_z);
00072               break;
00073             default:
00074               throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00075           }
00076         }
00077         else if (  leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY
00078                 && leaf.rhs.type_family == VECTOR_TYPE_FAMILY) // x = (y) + z, y being a subtree itself, z being a vector
00079         {
00080           statement_node const & y = expr[leaf.lhs.node_index];
00081 
00082           if (y.op.type_family == OPERATION_BINARY_TYPE_FAMILY)
00083           {
00084             // y might be  'v * alpha' or 'v / alpha' with vector v
00085             if (   (y.op.type == OPERATION_BINARY_MULT_TYPE || y.op.type == OPERATION_BINARY_DIV_TYPE)
00086                 &&  y.lhs.type_family == VECTOR_TYPE_FAMILY
00087                 &&  y.rhs.type_family == SCALAR_TYPE_FAMILY)
00088             {
00089               lhs_rhs_element u = root_node.lhs;
00090               lhs_rhs_element v = y.lhs;
00091               lhs_rhs_element w = leaf.rhs;
00092               lhs_rhs_element alpha = y.rhs;
00093 
00094               bool is_division = (y.op.type == OPERATION_BINARY_DIV_TYPE);
00095               switch (root_node.op.type)
00096               {
00097                 case OPERATION_BINARY_ASSIGN_TYPE:
00098                   detail::avbv(u,
00099                                v, alpha, 1, is_division, false,
00100                                w,   1.0, 1, false,       flip_sign_z);
00101                   break;
00102                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00103                   detail::avbv_v(u,
00104                                  v, alpha, 1, is_division, false,
00105                                  w,   1.0, 1, false,       flip_sign_z);
00106                   break;
00107                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00108                   detail::avbv_v(u,
00109                                  v, alpha, 1, is_division, true,
00110                                  w,   1.0, 1, false,       !flip_sign_z);
00111                   break;
00112                 default:
00113                   throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00114               }
00115             }
00116             else // no built-in kernel, we use a temporary.
00117             {
00118               statement_node new_root_y;
00119 
00120               new_root_y.lhs.type_family  = root_node.lhs.type_family;
00121               new_root_y.lhs.subtype      = root_node.lhs.subtype;
00122               new_root_y.lhs.numeric_type = root_node.lhs.numeric_type;
00123               detail::new_vector(new_root_y.lhs, (root_node.lhs.vector_float)->size());
00124 
00125               new_root_y.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00126               new_root_y.op.type        = OPERATION_BINARY_ASSIGN_TYPE;
00127 
00128               new_root_y.rhs.type_family  = COMPOSITE_OPERATION_FAMILY;
00129               new_root_y.rhs.subtype      = INVALID_SUBTYPE;
00130               new_root_y.rhs.numeric_type = INVALID_NUMERIC_TYPE;
00131               new_root_y.rhs.node_index   = leaf.lhs.node_index;
00132 
00133               // work on subexpression:
00134               // TODO: Catch exception, free temporary, then rethrow
00135               execute_vector(s, new_root_y);
00136 
00137               // now add:
00138               lhs_rhs_element u = root_node.lhs;
00139               lhs_rhs_element v = new_root_y.lhs;
00140               lhs_rhs_element w = leaf.rhs;
00141               switch (root_node.op.type)
00142               {
00143                 case OPERATION_BINARY_ASSIGN_TYPE:
00144                   detail::avbv(u,
00145                                v, 1.0, 1, false, false,
00146                                w, 1.0, 1, false, flip_sign_z);
00147                   break;
00148                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00149                   detail::avbv_v(u,
00150                                  v, 1.0, 1, false, false,
00151                                  w, 1.0, 1, false, flip_sign_z);
00152                   break;
00153                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00154                   detail::avbv_v(u,
00155                                  v, 1.0, 1, false, true,
00156                                  w, 1.0, 1, false, !flip_sign_z);
00157                   break;
00158                 default:
00159                   throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00160               }
00161 
00162               detail::delete_vector(new_root_y.lhs);
00163             }
00164           }
00165           else
00166             throw statement_not_supported_exception("Cannot deal with unary operations on vectors");
00167 
00168         }
00169         else if (  leaf.lhs.type_family == VECTOR_TYPE_FAMILY
00170                 && leaf.rhs.type_family == COMPOSITE_OPERATION_FAMILY) // x = y + (z), y being vector, z being a subtree itself
00171         {
00172           statement_node const & z = expr[leaf.rhs.node_index];
00173 
00174           if (z.op.type_family == OPERATION_BINARY_TYPE_FAMILY)
00175           {
00176             // z might be  'v * alpha' or 'v / alpha' with vector v
00177             if (   (z.op.type == OPERATION_BINARY_MULT_TYPE || z.op.type == OPERATION_BINARY_DIV_TYPE)
00178                 &&  z.lhs.type_family == VECTOR_TYPE_FAMILY
00179                 &&  z.rhs.type_family == SCALAR_TYPE_FAMILY)
00180             {
00181               lhs_rhs_element u = root_node.lhs;
00182               lhs_rhs_element v = leaf.rhs;
00183               lhs_rhs_element w = z.lhs;
00184               lhs_rhs_element beta = z.rhs;
00185 
00186               bool is_division = (z.op.type == OPERATION_BINARY_DIV_TYPE);
00187               switch (root_node.op.type)
00188               {
00189                 case OPERATION_BINARY_ASSIGN_TYPE:
00190                   detail::avbv(u,
00191                                v,  1.0, 1, false, false,
00192                                w, beta, 1, is_division, flip_sign_z);
00193                   break;
00194                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00195                   detail::avbv_v(u,
00196                                  v,  1.0, 1, false, false,
00197                                  w, beta, 1, is_division, flip_sign_z);
00198                   break;
00199                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00200                   detail::avbv_v(u,
00201                                  v,  1.0, 1, false, true,
00202                                  w, beta, 1, is_division, !flip_sign_z);
00203                   break;
00204                 default:
00205                   throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00206               }
00207             }
00208             else // no built-in kernel, we use a temporary.
00209             {
00210               statement_node new_root_z;
00211 
00212               new_root_z.lhs.type_family  = root_node.lhs.type_family;
00213               new_root_z.lhs.subtype      = root_node.lhs.subtype;
00214               new_root_z.lhs.numeric_type = root_node.lhs.numeric_type;
00215               detail::new_vector(new_root_z.lhs, (root_node.lhs.vector_float)->size());
00216 
00217               new_root_z.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00218               new_root_z.op.type   = OPERATION_BINARY_ASSIGN_TYPE;
00219 
00220               new_root_z.rhs.type_family  = COMPOSITE_OPERATION_FAMILY;
00221               new_root_z.rhs.subtype      = INVALID_SUBTYPE;
00222               new_root_z.rhs.numeric_type = INVALID_NUMERIC_TYPE;
00223               new_root_z.rhs.node_index   = leaf.rhs.node_index;
00224 
00225               // work on subexpression:
00226               // TODO: Catch exception, free temporary, then rethrow
00227               execute_vector(s, new_root_z);
00228 
00229               // now add:
00230               lhs_rhs_element u = root_node.lhs;
00231               lhs_rhs_element v = leaf.lhs;
00232               lhs_rhs_element w = new_root_z.lhs;
00233               switch (root_node.op.type)
00234               {
00235                 case OPERATION_BINARY_ASSIGN_TYPE:
00236                   detail::avbv(u,
00237                                v, 1.0, 1, false, false,
00238                                w, 1.0, 1, false, flip_sign_z);
00239                   break;
00240                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00241                   detail::avbv_v(u,
00242                                  v, 1.0, 1, false, false,
00243                                  w, 1.0, 1, false, flip_sign_z);
00244                   break;
00245                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00246                   detail::avbv_v(u,
00247                                  v, 1.0, 1, false, true,
00248                                  w, 1.0, 1, false, !flip_sign_z);
00249                   break;
00250                 default:
00251                   throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00252               }
00253 
00254               detail::delete_vector(new_root_z.lhs);
00255             }
00256           }
00257           else
00258             throw statement_not_supported_exception("Cannot deal with unary operations on vectors");
00259 
00260         }
00261         else if (  leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY
00262                 && leaf.rhs.type_family == COMPOSITE_OPERATION_FAMILY) // x = (y) + (z), y and z being subtrees
00263         {
00264           statement_node const & y = expr[leaf.lhs.node_index];
00265           statement_node const & z = expr[leaf.rhs.node_index];
00266 
00267           if (   y.op.type_family == OPERATION_BINARY_TYPE_FAMILY
00268               && z.op.type_family == OPERATION_BINARY_TYPE_FAMILY)
00269           {
00270             // z might be  'v * alpha' or 'v / alpha' with vector v
00271             if (   (y.op.type == OPERATION_BINARY_MULT_TYPE || y.op.type == OPERATION_BINARY_DIV_TYPE)
00272                 &&  y.lhs.type_family == VECTOR_TYPE_FAMILY
00273                 &&  y.rhs.type_family == SCALAR_TYPE_FAMILY
00274                 && (z.op.type == OPERATION_BINARY_MULT_TYPE || z.op.type == OPERATION_BINARY_DIV_TYPE)
00275                 &&  z.lhs.type_family == VECTOR_TYPE_FAMILY
00276                 &&  z.rhs.type_family == SCALAR_TYPE_FAMILY)
00277             {
00278               lhs_rhs_element u = root_node.lhs;
00279               lhs_rhs_element v = y.lhs;
00280               lhs_rhs_element w = z.lhs;
00281               lhs_rhs_element alpha = y.rhs;
00282               lhs_rhs_element beta  = z.rhs;
00283 
00284               bool is_division_y = (y.op.type == OPERATION_BINARY_DIV_TYPE);
00285               bool is_division_z = (z.op.type == OPERATION_BINARY_DIV_TYPE);
00286               switch (root_node.op.type)
00287               {
00288                 case OPERATION_BINARY_ASSIGN_TYPE:
00289                   detail::avbv(u,
00290                                v, alpha, 1, is_division_y, false,
00291                                w,  beta, 1, is_division_z, flip_sign_z);
00292                   break;
00293                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00294                   detail::avbv_v(u,
00295                                  v, alpha, 1, is_division_y, false,
00296                                  w,  beta, 1, is_division_z, flip_sign_z);
00297                   break;
00298                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00299                   detail::avbv_v(u,
00300                                  v, alpha, 1, is_division_y, true,
00301                                  w,  beta, 1, is_division_z, !flip_sign_z);
00302                   break;
00303                 default:
00304                   throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00305               }
00306             }
00307             else // no built-in kernel, we use a temporary.
00308             {
00309               statement_node new_root_y;
00310 
00311               new_root_y.lhs.type_family  = root_node.lhs.type_family;
00312               new_root_y.lhs.subtype      = root_node.lhs.subtype;
00313               new_root_y.lhs.numeric_type = root_node.lhs.numeric_type;
00314               detail::new_vector(new_root_y.lhs, (root_node.lhs.vector_float)->size());
00315 
00316               new_root_y.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00317               new_root_y.op.type   = OPERATION_BINARY_ASSIGN_TYPE;
00318 
00319               new_root_y.rhs.type_family  = COMPOSITE_OPERATION_FAMILY;
00320               new_root_y.rhs.subtype      = INVALID_SUBTYPE;
00321               new_root_y.rhs.numeric_type = INVALID_NUMERIC_TYPE;
00322               new_root_y.rhs.node_index   = leaf.lhs.node_index;
00323 
00324               // work on subexpression:
00325               // TODO: Catch exception, free temporary, then rethrow
00326               execute_vector(s, new_root_y);
00327 
00328               statement_node new_root_z;
00329 
00330               new_root_z.lhs.type_family  = root_node.lhs.type_family;
00331               new_root_z.lhs.subtype      = root_node.lhs.subtype;
00332               new_root_z.lhs.numeric_type = root_node.lhs.numeric_type;
00333               detail::new_vector(new_root_z.lhs, (root_node.lhs.vector_float)->size());
00334 
00335               new_root_z.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00336               new_root_z.op.type        = OPERATION_BINARY_ASSIGN_TYPE;
00337 
00338               new_root_z.rhs.type_family  = COMPOSITE_OPERATION_FAMILY;
00339               new_root_z.rhs.subtype      = INVALID_SUBTYPE;
00340               new_root_z.rhs.numeric_type = INVALID_NUMERIC_TYPE;
00341               new_root_z.rhs.node_index   = leaf.rhs.node_index;
00342 
00343               // work on subexpression:
00344               // TODO: Catch exception, free temporaries, then rethrow
00345               execute_vector(s, new_root_z);
00346 
00347               // now add:
00348               lhs_rhs_element u = root_node.lhs;
00349               lhs_rhs_element v = new_root_y.lhs;
00350               lhs_rhs_element w = new_root_z.lhs;
00351 
00352               switch (root_node.op.type)
00353               {
00354                 case OPERATION_BINARY_ASSIGN_TYPE:
00355                   detail::avbv(u,
00356                                v, 1.0, 1, false, false,
00357                                w, 1.0, 1, false, flip_sign_z);
00358                   break;
00359                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00360                   detail::avbv_v(u,
00361                                  v, 1.0, 1, false, false,
00362                                  w, 1.0, 1, false, flip_sign_z);
00363                   break;
00364                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00365                   detail::avbv_v(u,
00366                                  v, 1.0, 1, false, true,
00367                                  w, 1.0, 1, false, !flip_sign_z);
00368                   break;
00369                 default:
00370                   throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00371               }
00372 
00373               detail::delete_vector(new_root_y.lhs);
00374               detail::delete_vector(new_root_z.lhs);
00375             }
00376           }
00377           else
00378             throw statement_not_supported_exception("Cannot deal with unary operations on vectors");
00379         }
00380         else
00381           throw statement_not_supported_exception("Cannot deal with addition of vectors");
00382       }
00383       else if (leaf.op.type  == OPERATION_BINARY_MULT_TYPE || leaf.op.type  == OPERATION_BINARY_DIV_TYPE) // x = y * / alpha;
00384       {
00385         if (   leaf.lhs.type_family == VECTOR_TYPE_FAMILY
00386             && leaf.rhs.type_family == SCALAR_TYPE_FAMILY)
00387         {
00388           lhs_rhs_element u = root_node.lhs;
00389           lhs_rhs_element v = leaf.lhs;
00390           lhs_rhs_element alpha = leaf.rhs;
00391 
00392           bool is_division = (leaf.op.type  == OPERATION_BINARY_DIV_TYPE);
00393           switch (root_node.op.type)
00394           {
00395             case OPERATION_BINARY_ASSIGN_TYPE:
00396               detail::av(u,
00397                          v, alpha, 1, is_division, false);
00398               break;
00399             case OPERATION_BINARY_INPLACE_ADD_TYPE:
00400               detail::avbv(u,
00401                            u,   1.0, 1, false,       false,
00402                            v, alpha, 1, is_division, false);
00403               break;
00404             case OPERATION_BINARY_INPLACE_SUB_TYPE:
00405               detail::avbv(u,
00406                            u,   1.0, 1, false,       false,
00407                            v, alpha, 1, is_division, true);
00408               break;
00409             default:
00410               throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00411           }
00412 
00413         }
00414         else
00415           throw statement_not_supported_exception("Unsupported binary operator for OPERATION_BINARY_MULT_TYPE || OPERATION_BINARY_DIV_TYPE on leaf node.");
00416       }
00417       else
00418         throw statement_not_supported_exception("Unsupported binary operator for vector operations");
00419     }
00420 
00422     inline void execute_vector_vector(statement const &, statement_node const & root_node)
00423     {
00424       lhs_rhs_element u = root_node.lhs;
00425       lhs_rhs_element v = root_node.rhs;
00426       switch (root_node.op.type)
00427       {
00428         case OPERATION_BINARY_ASSIGN_TYPE:
00429           detail::av(u,
00430                      v, 1.0, 1, false, false);
00431           break;
00432         case OPERATION_BINARY_INPLACE_ADD_TYPE:
00433           detail::avbv(u,
00434                        u, 1.0, 1, false, false,
00435                        v, 1.0, 1, false, false);
00436           break;
00437         case OPERATION_BINARY_INPLACE_SUB_TYPE:
00438           detail::avbv(u,
00439                        u, 1.0, 1, false, false,
00440                        v, 1.0, 1, false, true);
00441           break;
00442         default:
00443           throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00444       }
00445 
00446     }
00447 
00449     inline void execute_vector(statement const & s, statement_node const & root_node)
00450     {
00451       switch (root_node.rhs.type_family)
00452       {
00453         case COMPOSITE_OPERATION_FAMILY:
00454           execute_vector_composite(s, root_node);
00455           break;
00456         case VECTOR_TYPE_FAMILY:
00457           execute_vector_vector(s, root_node);
00458           break;
00459         default:
00460           throw statement_not_supported_exception("Invalid rvalue encountered in vector assignment");
00461       }
00462     }
00463 
00464 
00465   }
00466 
00467 } //namespace viennacl
00468 
00469 #endif
00470