// ViennaCL - The Vienna Computing Library, version 1.5.0
00001 #ifndef VIENNACL_SCHEDULER_EXECUTE_MATRIX_HPP 00002 #define VIENNACL_SCHEDULER_EXECUTE_MATRIX_HPP 00003 00004 /* ========================================================================= 00005 Copyright (c) 2010-2013, Institute for Microelectronics, 00006 Institute for Analysis and Scientific Computing, 00007 TU Wien. 00008 Portions of this software are copyright by UChicago Argonne, LLC. 00009 00010 ----------------- 00011 ViennaCL - The Vienna Computing Library 00012 ----------------- 00013 00014 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 00015 00016 (A list of authors and contributors can be found in the PDF manual) 00017 00018 License: MIT (X11), see file LICENSE in the base directory 00019 ============================================================================= */ 00020 00021 00026 #include "viennacl/forwards.h" 00027 #include "viennacl/scheduler/forwards.h" 00028 #include "viennacl/scheduler/execute_matrix_dispatcher.hpp" 00029 00030 namespace viennacl 00031 { 00032 namespace scheduler 00033 { 00034 inline void execute_matrix(statement const & s, statement_node const & root_node); 00035 00036 00038 inline void execute_matrix_composite(statement const & s, statement_node const & root_node) 00039 { 00040 statement::container_type const & expr = s.array(); 00041 00042 statement_node const & leaf = expr[root_node.rhs.node_index]; 00043 00044 if (leaf.op.type == OPERATION_BINARY_ADD_TYPE || leaf.op.type == OPERATION_BINARY_SUB_TYPE) // x = (y) +- (z) where y and z are either matrices or expressions 00045 { 00046 bool flip_sign_z = (leaf.op.type == OPERATION_BINARY_SUB_TYPE); 00047 00048 if ( leaf.lhs.type_family == MATRIX_ROW_TYPE_FAMILY || leaf.lhs.type_family == MATRIX_COL_TYPE_FAMILY ) 00049 { 00050 lhs_rhs_element u = root_node.lhs; 00051 lhs_rhs_element v = leaf.lhs; 00052 lhs_rhs_element w = leaf.rhs; 00053 switch (root_node.op.type) 00054 { 00055 case OPERATION_BINARY_ASSIGN_TYPE: 00056 detail::ambm(u, 00057 v, 1.0, 1, false, false, 00058 w, 
1.0, 1, false, flip_sign_z); 00059 break; 00060 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00061 detail::ambm_m(u, 00062 v, 1.0, 1, false, false, 00063 w, 1.0, 1, false, flip_sign_z); 00064 break; 00065 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00066 detail::ambm_m(u, 00067 v, 1.0, 1, false, true, 00068 w, 1.0, 1, false, !flip_sign_z); 00069 break; 00070 default: 00071 throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)"); 00072 } 00073 } 00074 else if ( leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY 00075 && (leaf.rhs.type_family == MATRIX_COL_TYPE_FAMILY || leaf.rhs.type_family == MATRIX_ROW_TYPE_FAMILY) ) // x = (y) + z, y being a subtree itself, z being a matrix 00076 { 00077 statement_node const & y = expr[leaf.lhs.node_index]; 00078 00079 if (y.op.type_family == OPERATION_BINARY_TYPE_FAMILY) 00080 { 00081 // y might be 'A * alpha' or 'A / alpha' with matrix A 00082 if ( (y.op.type == OPERATION_BINARY_MULT_TYPE || y.op.type == OPERATION_BINARY_DIV_TYPE) 00083 && (y.lhs.type_family == MATRIX_ROW_TYPE_FAMILY || y.lhs.type_family == MATRIX_COL_TYPE_FAMILY) 00084 && (y.rhs.type_family == SCALAR_TYPE_FAMILY || y.rhs.type_family == HOST_SCALAR_TYPE_FAMILY)) 00085 { 00086 lhs_rhs_element u = root_node.lhs; 00087 lhs_rhs_element v = y.lhs; 00088 lhs_rhs_element w = leaf.rhs; 00089 lhs_rhs_element alpha = y.rhs; 00090 00091 bool is_division = (y.op.type == OPERATION_BINARY_DIV_TYPE); 00092 switch (root_node.op.type) 00093 { 00094 case OPERATION_BINARY_ASSIGN_TYPE: 00095 detail::ambm(u, 00096 v, alpha, 1, is_division, false, 00097 w, 1.0, 1, false, flip_sign_z); 00098 break; 00099 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00100 detail::ambm_m(u, 00101 v, alpha, 1, is_division, false, 00102 w, 1.0, 1, false, flip_sign_z); 00103 break; 00104 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00105 detail::ambm_m(u, 00106 v, alpha, 1, is_division, true, 00107 w, 1.0, 1, false, !flip_sign_z); 00108 break; 00109 
default: 00110 throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)"); 00111 } 00112 } 00113 else // no built-in kernel, we use a temporary. 00114 { 00115 statement_node new_root_y; 00116 00117 new_root_y.lhs.type_family = root_node.lhs.type_family; 00118 new_root_y.lhs.type = root_node.lhs.type; 00119 detail::new_matrix(new_root_y.lhs, (root_node.lhs.matrix_row_float)->size1(), (root_node.lhs.matrix_row_float)->size2()); 00120 00121 new_root_y.op.type_family = OPERATION_BINARY_TYPE_FAMILY; 00122 new_root_y.op.type = OPERATION_BINARY_ASSIGN_TYPE; 00123 00124 new_root_y.rhs.type_family = COMPOSITE_OPERATION_FAMILY; 00125 new_root_y.rhs.type = COMPOSITE_OPERATION_TYPE; 00126 new_root_y.rhs.node_index = leaf.lhs.node_index; 00127 00128 // work on subexpression: 00129 // TODO: Catch exception, free temporary, then rethrow 00130 execute_matrix(s, new_root_y); 00131 00132 // now add: 00133 lhs_rhs_element u = root_node.lhs; 00134 lhs_rhs_element v = new_root_y.lhs; 00135 lhs_rhs_element w = leaf.rhs; 00136 switch (root_node.op.type) 00137 { 00138 case OPERATION_BINARY_ASSIGN_TYPE: 00139 detail::ambm(u, 00140 v, 1.0, 1, false, false, 00141 w, 1.0, 1, false, flip_sign_z); 00142 break; 00143 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00144 detail::ambm_m(u, 00145 v, 1.0, 1, false, false, 00146 w, 1.0, 1, false, flip_sign_z); 00147 break; 00148 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00149 detail::ambm_m(u, 00150 v, 1.0, 1, false, true, 00151 w, 1.0, 1, false, !flip_sign_z); 00152 break; 00153 default: 00154 throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)"); 00155 } 00156 00157 detail::delete_matrix(new_root_y.lhs); 00158 } 00159 } 00160 else 00161 throw statement_not_supported_exception("Cannot deal with unary operations on matrices"); 00162 00163 } 00164 else if ( (leaf.lhs.type_family == MATRIX_COL_TYPE_FAMILY || 
leaf.lhs.type_family == MATRIX_ROW_TYPE_FAMILY) 00165 && leaf.rhs.type_family == COMPOSITE_OPERATION_FAMILY) // x = y + (z), y being a matrix, z being a subtree itself 00166 { 00167 statement_node const & z = expr[leaf.rhs.node_index]; 00168 00169 if (z.op.type_family == OPERATION_BINARY_TYPE_FAMILY) 00170 { 00171 // z might be 'A * alpha' or 'A / alpha' with matrix A 00172 if ( (z.op.type == OPERATION_BINARY_MULT_TYPE || z.op.type == OPERATION_BINARY_DIV_TYPE) 00173 && (z.lhs.type_family == MATRIX_COL_TYPE_FAMILY || z.lhs.type_family == MATRIX_ROW_TYPE_FAMILY) 00174 && (z.rhs.type_family == SCALAR_TYPE_FAMILY || z.rhs.type_family == HOST_SCALAR_TYPE_FAMILY)) 00175 { 00176 lhs_rhs_element u = root_node.lhs; 00177 lhs_rhs_element v = leaf.rhs; 00178 lhs_rhs_element w = z.lhs; 00179 lhs_rhs_element beta = z.rhs; 00180 00181 bool is_division = (z.op.type == OPERATION_BINARY_DIV_TYPE); 00182 switch (root_node.op.type) 00183 { 00184 case OPERATION_BINARY_ASSIGN_TYPE: 00185 detail::ambm(u, 00186 v, 1.0, 1, false, false, 00187 w, beta, 1, is_division, flip_sign_z); 00188 break; 00189 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00190 detail::ambm_m(u, 00191 v, 1.0, 1, false, false, 00192 w, beta, 1, is_division, flip_sign_z); 00193 break; 00194 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00195 detail::ambm_m(u, 00196 v, 1.0, 1, false, true, 00197 w, beta, 1, is_division, !flip_sign_z); 00198 break; 00199 default: 00200 throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)"); 00201 } 00202 } 00203 else // no built-in kernel, we use a temporary. 
00204 { 00205 statement_node new_root_z; 00206 00207 new_root_z.lhs.type_family = root_node.lhs.type_family; 00208 new_root_z.lhs.type = root_node.lhs.type; 00209 detail::new_matrix(new_root_z.lhs, (root_node.lhs.matrix_col_float)->size1(), (root_node.lhs.matrix_col_float)->size2()); 00210 00211 new_root_z.op.type_family = OPERATION_BINARY_TYPE_FAMILY; 00212 new_root_z.op.type = OPERATION_BINARY_ASSIGN_TYPE; 00213 00214 new_root_z.rhs.type_family = COMPOSITE_OPERATION_FAMILY; 00215 new_root_z.rhs.type = COMPOSITE_OPERATION_TYPE; 00216 new_root_z.rhs.node_index = leaf.rhs.node_index; 00217 00218 // work on subexpression: 00219 // TODO: Catch exception, free temporary, then rethrow 00220 execute_matrix(s, new_root_z); 00221 00222 // now add: 00223 lhs_rhs_element u = root_node.lhs; 00224 lhs_rhs_element v = leaf.lhs; 00225 lhs_rhs_element w = new_root_z.lhs; 00226 switch (root_node.op.type) 00227 { 00228 case OPERATION_BINARY_ASSIGN_TYPE: 00229 detail::ambm(u, 00230 v, 1.0, 1, false, false, 00231 w, 1.0, 1, false, flip_sign_z); 00232 break; 00233 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00234 detail::ambm_m(u, 00235 v, 1.0, 1, false, false, 00236 w, 1.0, 1, false, flip_sign_z); 00237 break; 00238 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00239 detail::ambm_m(u, 00240 v, 1.0, 1, false, true, 00241 w, 1.0, 1, false, !flip_sign_z); 00242 break; 00243 default: 00244 throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)"); 00245 } 00246 00247 detail::delete_matrix(new_root_z.lhs); 00248 } 00249 } 00250 else 00251 throw statement_not_supported_exception("Cannot deal with unary operations on matrices"); 00252 00253 } 00254 else if ( leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY 00255 && leaf.rhs.type_family == COMPOSITE_OPERATION_FAMILY) // x = (y) + (z), y and z being subtrees 00256 { 00257 statement_node const & y = expr[leaf.lhs.node_index]; 00258 statement_node const & z = 
expr[leaf.rhs.node_index]; 00259 00260 if ( y.op.type_family == OPERATION_BINARY_TYPE_FAMILY 00261 && z.op.type_family == OPERATION_BINARY_TYPE_FAMILY) 00262 { 00263 // z might be 'A * alpha' or 'A / alpha' with matrix A 00264 if ( (y.op.type == OPERATION_BINARY_MULT_TYPE || y.op.type == OPERATION_BINARY_DIV_TYPE) 00265 && (y.lhs.type_family == MATRIX_COL_TYPE_FAMILY || y.lhs.type_family == MATRIX_ROW_TYPE_FAMILY) 00266 && (y.rhs.type_family == SCALAR_TYPE_FAMILY || y.rhs.type_family == HOST_SCALAR_TYPE_FAMILY) 00267 && (z.op.type == OPERATION_BINARY_MULT_TYPE || z.op.type == OPERATION_BINARY_DIV_TYPE) 00268 && (z.lhs.type_family == MATRIX_COL_TYPE_FAMILY || z.lhs.type_family == MATRIX_ROW_TYPE_FAMILY) 00269 && (z.rhs.type_family == SCALAR_TYPE_FAMILY || z.rhs.type_family == HOST_SCALAR_TYPE_FAMILY)) 00270 { 00271 lhs_rhs_element u = root_node.lhs; 00272 lhs_rhs_element v = y.lhs; 00273 lhs_rhs_element w = z.lhs; 00274 lhs_rhs_element alpha = y.rhs; 00275 lhs_rhs_element beta = z.rhs; 00276 00277 bool is_division_y = (y.op.type == OPERATION_BINARY_DIV_TYPE); 00278 bool is_division_z = (z.op.type == OPERATION_BINARY_DIV_TYPE); 00279 switch (root_node.op.type) 00280 { 00281 case OPERATION_BINARY_ASSIGN_TYPE: 00282 detail::ambm(u, 00283 v, alpha, 1, is_division_y, false, 00284 w, beta, 1, is_division_z, flip_sign_z); 00285 break; 00286 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00287 detail::ambm_m(u, 00288 v, alpha, 1, is_division_y, false, 00289 w, beta, 1, is_division_z, flip_sign_z); 00290 break; 00291 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00292 detail::ambm_m(u, 00293 v, alpha, 1, is_division_y, true, 00294 w, beta, 1, is_division_z, !flip_sign_z); 00295 break; 00296 default: 00297 throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)"); 00298 } 00299 } 00300 else // no built-in kernel, we use a temporary. 
00301 { 00302 statement_node new_root_y; 00303 00304 new_root_y.lhs.type_family = root_node.lhs.type_family; 00305 new_root_y.lhs.type = root_node.lhs.type; 00306 detail::new_matrix(new_root_y.lhs, (root_node.lhs.matrix_col_float)->size1(), (root_node.lhs.matrix_col_float)->size2()); 00307 00308 new_root_y.op.type_family = OPERATION_BINARY_TYPE_FAMILY; 00309 new_root_y.op.type = OPERATION_BINARY_ASSIGN_TYPE; 00310 00311 new_root_y.rhs.type_family = COMPOSITE_OPERATION_FAMILY; 00312 new_root_y.rhs.type = COMPOSITE_OPERATION_TYPE; 00313 new_root_y.rhs.node_index = leaf.lhs.node_index; 00314 00315 // work on subexpression: 00316 // TODO: Catch exception, free temporary, then rethrow 00317 execute_matrix(s, new_root_y); 00318 00319 statement_node new_root_z; 00320 00321 new_root_z.lhs.type_family = root_node.lhs.type_family; 00322 new_root_z.lhs.type = root_node.lhs.type; 00323 detail::new_matrix(new_root_z.lhs, (root_node.lhs.matrix_col_float)->size1(), (root_node.lhs.matrix_col_float)->size2()); 00324 00325 new_root_z.op.type_family = OPERATION_BINARY_TYPE_FAMILY; 00326 new_root_z.op.type = OPERATION_BINARY_ASSIGN_TYPE; 00327 00328 new_root_z.rhs.type_family = COMPOSITE_OPERATION_FAMILY; 00329 new_root_z.rhs.type = COMPOSITE_OPERATION_TYPE; 00330 new_root_z.rhs.node_index = leaf.rhs.node_index; 00331 00332 // work on subexpression: 00333 // TODO: Catch exception, free temporaries, then rethrow 00334 execute_matrix(s, new_root_z); 00335 00336 // now add: 00337 lhs_rhs_element u = root_node.lhs; 00338 lhs_rhs_element v = new_root_y.lhs; 00339 lhs_rhs_element w = new_root_z.lhs; 00340 00341 switch (root_node.op.type) 00342 { 00343 case OPERATION_BINARY_ASSIGN_TYPE: 00344 detail::ambm(u, 00345 v, 1.0, 1, false, false, 00346 w, 1.0, 1, false, flip_sign_z); 00347 break; 00348 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00349 detail::ambm_m(u, 00350 v, 1.0, 1, false, false, 00351 w, 1.0, 1, false, flip_sign_z); 00352 break; 00353 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00354 
detail::ambm_m(u, 00355 v, 1.0, 1, false, true, 00356 w, 1.0, 1, false, !flip_sign_z); 00357 break; 00358 default: 00359 throw statement_not_supported_exception("Unsupported binary operator for matrices operation in root note (should be =, +=, or -=)"); 00360 } 00361 00362 detail::delete_matrix(new_root_y.lhs); 00363 detail::delete_matrix(new_root_z.lhs); 00364 } 00365 } 00366 else 00367 throw statement_not_supported_exception("Cannot deal with unary operations on matrices"); 00368 } 00369 else 00370 throw statement_not_supported_exception("Cannot deal with addition of matrices"); 00371 } 00372 else if (leaf.op.type == OPERATION_BINARY_MULT_TYPE || leaf.op.type == OPERATION_BINARY_DIV_TYPE) // x = y * / alpha; 00373 { 00374 if ( (leaf.lhs.type_family == MATRIX_COL_TYPE_FAMILY || leaf.lhs.type_family == MATRIX_ROW_TYPE_FAMILY) 00375 && (leaf.rhs.type_family == SCALAR_TYPE_FAMILY || leaf.rhs.type_family == HOST_SCALAR_TYPE_FAMILY)) 00376 { 00377 lhs_rhs_element u = root_node.lhs; 00378 lhs_rhs_element v = leaf.lhs; 00379 lhs_rhs_element alpha = leaf.rhs; 00380 00381 bool is_division = (leaf.op.type == OPERATION_BINARY_DIV_TYPE); 00382 switch (root_node.op.type) 00383 { 00384 case OPERATION_BINARY_ASSIGN_TYPE: 00385 detail::am(u, 00386 v, alpha, 1, is_division, false); 00387 break; 00388 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00389 detail::ambm(u, 00390 u, 1.0, 1, false, false, 00391 v, alpha, 1, is_division, false); 00392 break; 00393 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00394 detail::ambm(u, 00395 u, 1.0, 1, false, false, 00396 v, alpha, 1, is_division, true); 00397 break; 00398 default: 00399 throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)"); 00400 } 00401 00402 } 00403 else 00404 throw statement_not_supported_exception("Unsupported binary operator for OPERATION_BINARY_MULT_TYPE || OPERATION_BINARY_DIV_TYPE on leaf node."); 00405 } 00406 else 00407 throw 
statement_not_supported_exception("Unsupported binary operator for matrix operations"); 00408 } 00409 00411 inline void execute_matrix_matrix(statement const & /*s*/, statement_node const & root_node) 00412 { 00413 lhs_rhs_element u = root_node.lhs; 00414 lhs_rhs_element v = root_node.rhs; 00415 switch (root_node.op.type) 00416 { 00417 case OPERATION_BINARY_ASSIGN_TYPE: 00418 detail::am(u, 00419 v, 1.0, 1, false, false); 00420 break; 00421 case OPERATION_BINARY_INPLACE_ADD_TYPE: 00422 detail::ambm(u, 00423 u, 1.0, 1, false, false, 00424 v, 1.0, 1, false, false); 00425 break; 00426 case OPERATION_BINARY_INPLACE_SUB_TYPE: 00427 detail::ambm(u, 00428 u, 1.0, 1, false, false, 00429 v, 1.0, 1, false, true); 00430 break; 00431 default: 00432 throw statement_not_supported_exception("Unsupported binary operator for matrix operation in root note (should be =, +=, or -=)"); 00433 } 00434 } 00435 00437 inline void execute_matrix(statement const & s, statement_node const & root_node) 00438 { 00439 switch (root_node.rhs.type_family) 00440 { 00441 case COMPOSITE_OPERATION_FAMILY: 00442 execute_matrix_composite(s, root_node); 00443 break; 00444 case MATRIX_COL_TYPE_FAMILY: 00445 case MATRIX_ROW_TYPE_FAMILY: 00446 execute_matrix_matrix(s, root_node); 00447 break; 00448 default: 00449 throw statement_not_supported_exception("Invalid rvalue encountered in matrix assignment"); 00450 } 00451 } 00452 00453 00454 } 00455 00456 } //namespace viennacl 00457 00458 #endif 00459