ViennaCL - The Vienna Computing Library  1.5.0
viennacl/scheduler/execute_matrix.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_SCHEDULER_EXECUTE_MATRIX_HPP
00002 #define VIENNACL_SCHEDULER_EXECUTE_MATRIX_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2013, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00021 
00026 #include "viennacl/forwards.h"
00027 #include "viennacl/scheduler/forwards.h"
00028 #include "viennacl/scheduler/execute_matrix_dispatcher.hpp"
00029 
00030 namespace viennacl
00031 {
00032   namespace scheduler
00033   {
00034     inline void execute_matrix(statement const & s, statement_node const & root_node);
00035 
00036 
00038     inline void execute_matrix_composite(statement const & s, statement_node const & root_node)
00039     {
00040       statement::container_type const & expr = s.array();
00041 
00042       statement_node const & leaf = expr[root_node.rhs.node_index];
00043 
00044       if (leaf.op.type  == OPERATION_BINARY_ADD_TYPE || leaf.op.type  == OPERATION_BINARY_SUB_TYPE) // x = (y) +- (z)  where y and z are either matrices or expressions
00045       {
00046         bool flip_sign_z = (leaf.op.type  == OPERATION_BINARY_SUB_TYPE);
00047 
00048         if ( leaf.lhs.type_family == MATRIX_ROW_TYPE_FAMILY || leaf.lhs.type_family == MATRIX_COL_TYPE_FAMILY )
00049         {
00050           lhs_rhs_element u = root_node.lhs;
00051           lhs_rhs_element v = leaf.lhs;
00052           lhs_rhs_element w = leaf.rhs;
00053           switch (root_node.op.type)
00054           {
00055             case OPERATION_BINARY_ASSIGN_TYPE:
00056               detail::ambm(u,
00057                            v, 1.0, 1, false, false,
00058                            w, 1.0, 1, false, flip_sign_z);
00059               break;
00060             case OPERATION_BINARY_INPLACE_ADD_TYPE:
00061               detail::ambm_m(u,
00062                              v, 1.0, 1, false, false,
00063                              w, 1.0, 1, false, flip_sign_z);
00064               break;
00065             case OPERATION_BINARY_INPLACE_SUB_TYPE:
00066               detail::ambm_m(u,
00067                              v, 1.0, 1, false, true,
00068                              w, 1.0, 1, false, !flip_sign_z);
00069               break;
00070             default:
00071               throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)");
00072           }
00073         }
00074         else if (  leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY
00075                 && (leaf.rhs.type_family == MATRIX_COL_TYPE_FAMILY || leaf.rhs.type_family == MATRIX_ROW_TYPE_FAMILY) ) // x = (y) + z, y being a subtree itself, z being a matrix
00076         {
00077           statement_node const & y = expr[leaf.lhs.node_index];
00078 
00079           if (y.op.type_family == OPERATION_BINARY_TYPE_FAMILY)
00080           {
00081             // y might be  'A * alpha' or 'A / alpha' with matrix A
00082             if (   (y.op.type == OPERATION_BINARY_MULT_TYPE     || y.op.type == OPERATION_BINARY_DIV_TYPE)
00083                 && (y.lhs.type_family == MATRIX_ROW_TYPE_FAMILY || y.lhs.type_family == MATRIX_COL_TYPE_FAMILY)
00084                 && (y.rhs.type_family == SCALAR_TYPE_FAMILY     || y.rhs.type_family == HOST_SCALAR_TYPE_FAMILY))
00085             {
00086               lhs_rhs_element u = root_node.lhs;
00087               lhs_rhs_element v = y.lhs;
00088               lhs_rhs_element w = leaf.rhs;
00089               lhs_rhs_element alpha = y.rhs;
00090 
00091               bool is_division = (y.op.type == OPERATION_BINARY_DIV_TYPE);
00092               switch (root_node.op.type)
00093               {
00094                 case OPERATION_BINARY_ASSIGN_TYPE:
00095                   detail::ambm(u,
00096                                v, alpha, 1, is_division, false,
00097                                w,   1.0, 1, false,       flip_sign_z);
00098                   break;
00099                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00100                   detail::ambm_m(u,
00101                                  v, alpha, 1, is_division, false,
00102                                  w,   1.0, 1, false,       flip_sign_z);
00103                   break;
00104                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00105                   detail::ambm_m(u,
00106                                  v, alpha, 1, is_division, true,
00107                                  w,   1.0, 1, false,       !flip_sign_z);
00108                   break;
00109                 default:
00110                   throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)");
00111               }
00112             }
00113             else // no built-in kernel, we use a temporary.
00114             {
00115               statement_node new_root_y;
00116 
00117               new_root_y.lhs.type_family = root_node.lhs.type_family;
00118               new_root_y.lhs.type        = root_node.lhs.type;
00119               detail::new_matrix(new_root_y.lhs, (root_node.lhs.matrix_row_float)->size1(),  (root_node.lhs.matrix_row_float)->size2());
00120 
00121               new_root_y.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00122               new_root_y.op.type        = OPERATION_BINARY_ASSIGN_TYPE;
00123 
00124               new_root_y.rhs.type_family = COMPOSITE_OPERATION_FAMILY;
00125               new_root_y.rhs.type        = COMPOSITE_OPERATION_TYPE;
00126               new_root_y.rhs.node_index  = leaf.lhs.node_index;
00127 
00128               // work on subexpression:
00129               // TODO: Catch exception, free temporary, then rethrow
00130               execute_matrix(s, new_root_y);
00131 
00132               // now add:
00133               lhs_rhs_element u = root_node.lhs;
00134               lhs_rhs_element v = new_root_y.lhs;
00135               lhs_rhs_element w = leaf.rhs;
00136               switch (root_node.op.type)
00137               {
00138                 case OPERATION_BINARY_ASSIGN_TYPE:
00139                   detail::ambm(u,
00140                                v, 1.0, 1, false, false,
00141                                w, 1.0, 1, false, flip_sign_z);
00142                   break;
00143                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00144                   detail::ambm_m(u,
00145                                  v, 1.0, 1, false, false,
00146                                  w, 1.0, 1, false, flip_sign_z);
00147                   break;
00148                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00149                   detail::ambm_m(u,
00150                                  v, 1.0, 1, false, true,
00151                                  w, 1.0, 1, false, !flip_sign_z);
00152                   break;
00153                 default:
00154                   throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)");
00155               }
00156 
00157               detail::delete_matrix(new_root_y.lhs);
00158             }
00159           }
00160           else
00161             throw statement_not_supported_exception("Cannot deal with unary operations on matrices");
00162 
00163         }
00164         else if (  (leaf.lhs.type_family == MATRIX_COL_TYPE_FAMILY || leaf.lhs.type_family == MATRIX_ROW_TYPE_FAMILY)
00165                 && leaf.rhs.type_family == COMPOSITE_OPERATION_FAMILY) // x = y + (z), y being a matrix, z being a subtree itself
00166         {
00167           statement_node const & z = expr[leaf.rhs.node_index];
00168 
00169           if (z.op.type_family == OPERATION_BINARY_TYPE_FAMILY)
00170           {
00171             // z might be  'A * alpha' or 'A / alpha' with matrix A
00172             if (   (z.op.type == OPERATION_BINARY_MULT_TYPE     || z.op.type == OPERATION_BINARY_DIV_TYPE)
00173                 && (z.lhs.type_family == MATRIX_COL_TYPE_FAMILY || z.lhs.type_family == MATRIX_ROW_TYPE_FAMILY)
00174                 && (z.rhs.type_family == SCALAR_TYPE_FAMILY     || z.rhs.type_family == HOST_SCALAR_TYPE_FAMILY))
00175             {
00176               lhs_rhs_element u = root_node.lhs;
00177               lhs_rhs_element v = leaf.rhs;
00178               lhs_rhs_element w = z.lhs;
00179               lhs_rhs_element beta = z.rhs;
00180 
00181               bool is_division = (z.op.type == OPERATION_BINARY_DIV_TYPE);
00182               switch (root_node.op.type)
00183               {
00184                 case OPERATION_BINARY_ASSIGN_TYPE:
00185                   detail::ambm(u,
00186                                v,  1.0, 1, false, false,
00187                                w, beta, 1, is_division, flip_sign_z);
00188                   break;
00189                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00190                   detail::ambm_m(u,
00191                                  v,  1.0, 1, false, false,
00192                                  w, beta, 1, is_division, flip_sign_z);
00193                   break;
00194                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00195                   detail::ambm_m(u,
00196                                  v,  1.0, 1, false, true,
00197                                  w, beta, 1, is_division, !flip_sign_z);
00198                   break;
00199                 default:
00200                   throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)");
00201               }
00202             }
00203             else // no built-in kernel, we use a temporary.
00204             {
00205               statement_node new_root_z;
00206 
00207               new_root_z.lhs.type_family = root_node.lhs.type_family;
00208               new_root_z.lhs.type        = root_node.lhs.type;
00209               detail::new_matrix(new_root_z.lhs, (root_node.lhs.matrix_col_float)->size1(), (root_node.lhs.matrix_col_float)->size2());
00210 
00211               new_root_z.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00212               new_root_z.op.type        = OPERATION_BINARY_ASSIGN_TYPE;
00213 
00214               new_root_z.rhs.type_family = COMPOSITE_OPERATION_FAMILY;
00215               new_root_z.rhs.type        = COMPOSITE_OPERATION_TYPE;
00216               new_root_z.rhs.node_index  = leaf.rhs.node_index;
00217 
00218               // work on subexpression:
00219               // TODO: Catch exception, free temporary, then rethrow
00220               execute_matrix(s, new_root_z);
00221 
00222               // now add:
00223               lhs_rhs_element u = root_node.lhs;
00224               lhs_rhs_element v = leaf.lhs;
00225               lhs_rhs_element w = new_root_z.lhs;
00226               switch (root_node.op.type)
00227               {
00228                 case OPERATION_BINARY_ASSIGN_TYPE:
00229                   detail::ambm(u,
00230                                v, 1.0, 1, false, false,
00231                                w, 1.0, 1, false, flip_sign_z);
00232                   break;
00233                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00234                   detail::ambm_m(u,
00235                                  v, 1.0, 1, false, false,
00236                                  w, 1.0, 1, false, flip_sign_z);
00237                   break;
00238                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00239                   detail::ambm_m(u,
00240                                  v, 1.0, 1, false, true,
00241                                  w, 1.0, 1, false, !flip_sign_z);
00242                   break;
00243                 default:
00244                   throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)");
00245               }
00246 
00247               detail::delete_matrix(new_root_z.lhs);
00248             }
00249           }
00250           else
00251             throw statement_not_supported_exception("Cannot deal with unary operations on matrices");
00252 
00253         }
00254         else if (  leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY
00255                 && leaf.rhs.type_family == COMPOSITE_OPERATION_FAMILY) // x = (y) + (z), y and z being subtrees
00256         {
00257           statement_node const & y = expr[leaf.lhs.node_index];
00258           statement_node const & z = expr[leaf.rhs.node_index];
00259 
00260           if (   y.op.type_family == OPERATION_BINARY_TYPE_FAMILY
00261               && z.op.type_family == OPERATION_BINARY_TYPE_FAMILY)
00262           {
00263             // z might be  'A * alpha' or 'A / alpha' with matrix A
00264             if (   (y.op.type == OPERATION_BINARY_MULT_TYPE     || y.op.type == OPERATION_BINARY_DIV_TYPE)
00265                 && (y.lhs.type_family == MATRIX_COL_TYPE_FAMILY || y.lhs.type_family == MATRIX_ROW_TYPE_FAMILY)
00266                 && (y.rhs.type_family == SCALAR_TYPE_FAMILY     || y.rhs.type_family == HOST_SCALAR_TYPE_FAMILY)
00267                 && (z.op.type == OPERATION_BINARY_MULT_TYPE     || z.op.type == OPERATION_BINARY_DIV_TYPE)
00268                 && (z.lhs.type_family == MATRIX_COL_TYPE_FAMILY || z.lhs.type_family == MATRIX_ROW_TYPE_FAMILY)
00269                 && (z.rhs.type_family == SCALAR_TYPE_FAMILY     || z.rhs.type_family == HOST_SCALAR_TYPE_FAMILY))
00270             {
00271               lhs_rhs_element u = root_node.lhs;
00272               lhs_rhs_element v = y.lhs;
00273               lhs_rhs_element w = z.lhs;
00274               lhs_rhs_element alpha = y.rhs;
00275               lhs_rhs_element beta  = z.rhs;
00276 
00277               bool is_division_y = (y.op.type == OPERATION_BINARY_DIV_TYPE);
00278               bool is_division_z = (z.op.type == OPERATION_BINARY_DIV_TYPE);
00279               switch (root_node.op.type)
00280               {
00281                 case OPERATION_BINARY_ASSIGN_TYPE:
00282                   detail::ambm(u,
00283                                v, alpha, 1, is_division_y, false,
00284                                w,  beta, 1, is_division_z, flip_sign_z);
00285                   break;
00286                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00287                   detail::ambm_m(u,
00288                                  v, alpha, 1, is_division_y, false,
00289                                  w,  beta, 1, is_division_z, flip_sign_z);
00290                   break;
00291                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00292                   detail::ambm_m(u,
00293                                  v, alpha, 1, is_division_y, true,
00294                                  w,  beta, 1, is_division_z, !flip_sign_z);
00295                   break;
00296                 default:
00297                   throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)");
00298               }
00299             }
00300             else // no built-in kernel, we use a temporary.
00301             {
00302               statement_node new_root_y;
00303 
00304               new_root_y.lhs.type_family = root_node.lhs.type_family;
00305               new_root_y.lhs.type        = root_node.lhs.type;
00306               detail::new_matrix(new_root_y.lhs, (root_node.lhs.matrix_col_float)->size1(), (root_node.lhs.matrix_col_float)->size2());
00307 
00308               new_root_y.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00309               new_root_y.op.type        = OPERATION_BINARY_ASSIGN_TYPE;
00310 
00311               new_root_y.rhs.type_family = COMPOSITE_OPERATION_FAMILY;
00312               new_root_y.rhs.type        = COMPOSITE_OPERATION_TYPE;
00313               new_root_y.rhs.node_index  = leaf.lhs.node_index;
00314 
00315               // work on subexpression:
00316               // TODO: Catch exception, free temporary, then rethrow
00317               execute_matrix(s, new_root_y);
00318 
00319               statement_node new_root_z;
00320 
00321               new_root_z.lhs.type_family = root_node.lhs.type_family;
00322               new_root_z.lhs.type        = root_node.lhs.type;
00323               detail::new_matrix(new_root_z.lhs, (root_node.lhs.matrix_col_float)->size1(), (root_node.lhs.matrix_col_float)->size2());
00324 
00325               new_root_z.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00326               new_root_z.op.type        = OPERATION_BINARY_ASSIGN_TYPE;
00327 
00328               new_root_z.rhs.type_family = COMPOSITE_OPERATION_FAMILY;
00329               new_root_z.rhs.type        = COMPOSITE_OPERATION_TYPE;
00330               new_root_z.rhs.node_index  = leaf.rhs.node_index;
00331 
00332               // work on subexpression:
00333               // TODO: Catch exception, free temporaries, then rethrow
00334               execute_matrix(s, new_root_z);
00335 
00336               // now add:
00337               lhs_rhs_element u = root_node.lhs;
00338               lhs_rhs_element v = new_root_y.lhs;
00339               lhs_rhs_element w = new_root_z.lhs;
00340 
00341               switch (root_node.op.type)
00342               {
00343                 case OPERATION_BINARY_ASSIGN_TYPE:
00344                   detail::ambm(u,
00345                                v, 1.0, 1, false, false,
00346                                w, 1.0, 1, false, flip_sign_z);
00347                   break;
00348                 case OPERATION_BINARY_INPLACE_ADD_TYPE:
00349                   detail::ambm_m(u,
00350                                  v, 1.0, 1, false, false,
00351                                  w, 1.0, 1, false, flip_sign_z);
00352                   break;
00353                 case OPERATION_BINARY_INPLACE_SUB_TYPE:
00354                   detail::ambm_m(u,
00355                                  v, 1.0, 1, false, true,
00356                                  w, 1.0, 1, false, !flip_sign_z);
00357                   break;
00358                 default:
00359                   throw statement_not_supported_exception("Unsupported binary operator for matrices operation in root note (should be =, +=, or -=)");
00360               }
00361 
00362               detail::delete_matrix(new_root_y.lhs);
00363               detail::delete_matrix(new_root_z.lhs);
00364             }
00365           }
00366           else
00367             throw statement_not_supported_exception("Cannot deal with unary operations on matrices");
00368         }
00369         else
00370           throw statement_not_supported_exception("Cannot deal with addition of matrices");
00371       }
00372       else if (leaf.op.type  == OPERATION_BINARY_MULT_TYPE || leaf.op.type  == OPERATION_BINARY_DIV_TYPE) // x = y * / alpha;
00373       {
00374         if (   (leaf.lhs.type_family == MATRIX_COL_TYPE_FAMILY || leaf.lhs.type_family == MATRIX_ROW_TYPE_FAMILY)
00375             && (leaf.rhs.type_family == SCALAR_TYPE_FAMILY || leaf.rhs.type_family == HOST_SCALAR_TYPE_FAMILY))
00376         {
00377           lhs_rhs_element u = root_node.lhs;
00378           lhs_rhs_element v = leaf.lhs;
00379           lhs_rhs_element alpha = leaf.rhs;
00380 
00381           bool is_division = (leaf.op.type  == OPERATION_BINARY_DIV_TYPE);
00382           switch (root_node.op.type)
00383           {
00384             case OPERATION_BINARY_ASSIGN_TYPE:
00385               detail::am(u,
00386                          v, alpha, 1, is_division, false);
00387               break;
00388             case OPERATION_BINARY_INPLACE_ADD_TYPE:
00389               detail::ambm(u,
00390                            u,   1.0, 1, false,       false,
00391                            v, alpha, 1, is_division, false);
00392               break;
00393             case OPERATION_BINARY_INPLACE_SUB_TYPE:
00394               detail::ambm(u,
00395                            u,   1.0, 1, false,       false,
00396                            v, alpha, 1, is_division, true);
00397               break;
00398             default:
00399               throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)");
00400           }
00401 
00402         }
00403         else
00404           throw statement_not_supported_exception("Unsupported binary operator for OPERATION_BINARY_MULT_TYPE || OPERATION_BINARY_DIV_TYPE on leaf node.");
00405       }
00406       else
00407         throw statement_not_supported_exception("Unsupported binary operator for matrix operations");
00408     }
00409 
00411     inline void execute_matrix_matrix(statement const & /*s*/, statement_node const & root_node)
00412     {
00413       lhs_rhs_element u = root_node.lhs;
00414       lhs_rhs_element v = root_node.rhs;
00415       switch (root_node.op.type)
00416       {
00417         case OPERATION_BINARY_ASSIGN_TYPE:
00418           detail::am(u,
00419                      v, 1.0, 1, false, false);
00420           break;
00421         case OPERATION_BINARY_INPLACE_ADD_TYPE:
00422           detail::ambm(u,
00423                        u, 1.0, 1, false, false,
00424                        v, 1.0, 1, false, false);
00425           break;
00426         case OPERATION_BINARY_INPLACE_SUB_TYPE:
00427           detail::ambm(u,
00428                        u, 1.0, 1, false, false,
00429                        v, 1.0, 1, false, true);
00430           break;
00431         default:
00432           throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)");
00433       }
00434     }
00435 
00437     inline void execute_matrix(statement const & s, statement_node const & root_node)
00438     {
00439       switch (root_node.rhs.type_family)
00440       {
00441         case COMPOSITE_OPERATION_FAMILY:
00442           execute_matrix_composite(s, root_node);
00443           break;
00444         case MATRIX_COL_TYPE_FAMILY:
00445         case MATRIX_ROW_TYPE_FAMILY:
00446           execute_matrix_matrix(s, root_node);
00447           break;
00448         default:
00449           throw statement_not_supported_exception("Invalid rvalue encountered in matrix assignment");
00450       }
00451     }
00452 
00453 
00454   }
00455 
00456 } //namespace viennacl
00457 
00458 #endif
00459