ViennaCL - The Vienna Computing Library 1.5.0
viennacl/linalg/opencl/kernels/hyb_matrix.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_HYB_MATRIX_HPP
00002 #define VIENNACL_LINALG_OPENCL_KERNELS_HYB_MATRIX_HPP
00003 
00004 #include "viennacl/tools/tools.hpp"
00005 #include "viennacl/ocl/kernel.hpp"
00006 #include "viennacl/ocl/platform.hpp"
00007 #include "viennacl/ocl/utils.hpp"
00008 
00009 #include "viennacl/linalg/opencl/common.hpp"
00010 
00013 namespace viennacl
00014 {
00015   namespace linalg
00016   {
00017     namespace opencl
00018     {
00019       namespace kernels
00020       {
00021 
00023 
00024         template <typename StringType>
00025         void generate_hyb_vec_mul(StringType & source, std::string const & numeric_string)
00026         {
00027           source.append("__kernel void vec_mul( \n");
00028           source.append("  const __global int* ell_coords, \n");
00029           source.append("  const __global "); source.append(numeric_string); source.append("* ell_elements, \n");
00030           source.append("  const __global uint* csr_rows, \n");
00031           source.append("  const __global uint* csr_cols, \n");
00032           source.append("  const __global "); source.append(numeric_string); source.append("* csr_elements, \n");
00033           source.append("  const __global "); source.append(numeric_string); source.append(" * x, \n");
00034           source.append("  uint4 layout_x, \n");
00035           source.append("  __global "); source.append(numeric_string); source.append(" * result, \n");
00036           source.append("  uint4 layout_result, \n");
00037           source.append("  unsigned int row_num, \n");
00038           source.append("  unsigned int internal_row_num, \n");
00039           source.append("  unsigned int items_per_row, \n");
00040           source.append("  unsigned int aligned_items_per_row) \n");
00041           source.append("{ \n");
00042           source.append("  uint glb_id = get_global_id(0); \n");
00043           source.append("  uint glb_sz = get_global_size(0); \n");
00044 
00045           source.append("  for(uint row_id = glb_id; row_id < row_num; row_id += glb_sz) { \n");
00046           source.append("    "); source.append(numeric_string); source.append(" sum = 0; \n");
00047 
00048           source.append("    uint offset = row_id; \n");
00049           source.append("    for(uint item_id = 0; item_id < items_per_row; item_id++, offset += internal_row_num) { \n");
00050           source.append("      "); source.append(numeric_string); source.append(" val = ell_elements[offset]; \n");
00051 
00052           source.append("      if(val != ("); source.append(numeric_string); source.append(")0) { \n");
00053           source.append("        int col = ell_coords[offset]; \n");
00054           source.append("        sum += (x[col * layout_x.y + layout_x.x] * val); \n");
00055           source.append("      } \n");
00056 
00057           source.append("    } \n");
00058 
00059           source.append("    uint col_begin = csr_rows[row_id]; \n");
00060           source.append("    uint col_end   = csr_rows[row_id + 1]; \n");
00061 
00062           source.append("    for(uint item_id = col_begin; item_id < col_end; item_id++) {  \n");
00063           source.append("      sum += (x[csr_cols[item_id] * layout_x.y + layout_x.x] * csr_elements[item_id]); \n");
00064           source.append("    } \n");
00065 
00066           source.append("    result[row_id * layout_result.y + layout_result.x] = sum; \n");
00067           source.append("  } \n");
00068           source.append("} \n");
00069         }
00070 
00071         namespace detail
00072         {
00073           template <typename StringType>
00074           void generate_hyb_matrix_dense_matrix_mul(StringType & source, std::string const & numeric_string,
00075                                                     bool B_transposed, bool B_row_major, bool C_row_major)
00076           {
00077             source.append("__kernel void ");
00078             source.append(viennacl::linalg::opencl::detail::sparse_dense_matmult_kernel_name(B_transposed, B_row_major, C_row_major));
00079             source.append("( \n");
00080             source.append("  const __global int* ell_coords, \n");
00081             source.append("  const __global "); source.append(numeric_string); source.append("* ell_elements, \n");
00082             source.append("  const __global uint* csr_rows, \n");
00083             source.append("  const __global uint* csr_cols, \n");
00084             source.append("  const __global "); source.append(numeric_string); source.append("* csr_elements, \n");
00085             source.append("  unsigned int row_num, \n");
00086             source.append("  unsigned int internal_row_num, \n");
00087             source.append("  unsigned int items_per_row, \n");
00088             source.append("  unsigned int aligned_items_per_row, \n");
00089             source.append("    __global const "); source.append(numeric_string); source.append("* d_mat, \n");
00090             source.append("    unsigned int d_mat_row_start, \n");
00091             source.append("    unsigned int d_mat_col_start, \n");
00092             source.append("    unsigned int d_mat_row_inc, \n");
00093             source.append("    unsigned int d_mat_col_inc, \n");
00094             source.append("    unsigned int d_mat_row_size, \n");
00095             source.append("    unsigned int d_mat_col_size, \n");
00096             source.append("    unsigned int d_mat_internal_rows, \n");
00097             source.append("    unsigned int d_mat_internal_cols, \n");
00098             source.append("    __global "); source.append(numeric_string); source.append(" * result, \n");
00099             source.append("    unsigned int result_row_start, \n");
00100             source.append("    unsigned int result_col_start, \n");
00101             source.append("    unsigned int result_row_inc, \n");
00102             source.append("    unsigned int result_col_inc, \n");
00103             source.append("    unsigned int result_row_size, \n");
00104             source.append("    unsigned int result_col_size, \n");
00105             source.append("    unsigned int result_internal_rows, \n");
00106             source.append("    unsigned int result_internal_cols) { \n");
00107 
00108             source.append("  uint glb_id = get_global_id(0); \n");
00109             source.append("  uint glb_sz = get_global_size(0); \n");
00110 
00111             source.append("  for(uint result_col = 0; result_col < result_col_size; ++result_col) { \n");
00112             source.append("   for(uint row_id = glb_id; row_id < row_num; row_id += glb_sz) { \n");
00113             source.append("    "); source.append(numeric_string); source.append(" sum = 0; \n");
00114 
00115             source.append("    uint offset = row_id; \n");
00116             source.append("    for(uint item_id = 0; item_id < items_per_row; item_id++, offset += internal_row_num) { \n");
00117             source.append("      "); source.append(numeric_string); source.append(" val = ell_elements[offset]; \n");
00118 
00119             source.append("      if(val != ("); source.append(numeric_string); source.append(")0) { \n");
00120             source.append("        int col = ell_coords[offset]; \n");
00121             if (B_transposed && B_row_major)
00122               source.append("      sum += d_mat[ (d_mat_row_start + result_col * d_mat_row_inc) * d_mat_internal_cols +  d_mat_col_start +        col * d_mat_col_inc                        ] * val; \n");
00123             else if (B_transposed && !B_row_major)
00124               source.append("      sum += d_mat[ (d_mat_row_start + result_col * d_mat_row_inc)                       + (d_mat_col_start +        col * d_mat_col_inc) * d_mat_internal_rows ] * val; \n");
00125             else if (!B_transposed && B_row_major)
00126               source.append("      sum += d_mat[ (d_mat_row_start +        col * d_mat_row_inc) * d_mat_internal_cols +  d_mat_col_start + result_col * d_mat_col_inc                        ] * val; \n");
00127             else
00128               source.append("      sum += d_mat[ (d_mat_row_start +        col * d_mat_row_inc)                       + (d_mat_col_start + result_col * d_mat_col_inc) * d_mat_internal_rows ] * val; \n");
00129             source.append("      } \n");
00130 
00131             source.append("    } \n");
00132 
00133             source.append("    uint col_begin = csr_rows[row_id]; \n");
00134             source.append("    uint col_end   = csr_rows[row_id + 1]; \n");
00135 
00136             source.append("    for(uint item_id = col_begin; item_id < col_end; item_id++) {  \n");
00137             if (B_transposed && B_row_major)
00138               source.append("      sum += d_mat[ (d_mat_row_start +        result_col * d_mat_row_inc) * d_mat_internal_cols +  d_mat_col_start + csr_cols[item_id] * d_mat_col_inc                        ] * csr_elements[item_id]; \n");
00139             else if (B_transposed && !B_row_major)
00140               source.append("      sum += d_mat[ (d_mat_row_start +        result_col * d_mat_row_inc)                       + (d_mat_col_start + csr_cols[item_id] * d_mat_col_inc) * d_mat_internal_rows ] * csr_elements[item_id]; \n");
00141             else if (!B_transposed && B_row_major)
00142               source.append("      sum += d_mat[ (d_mat_row_start + csr_cols[item_id] * d_mat_row_inc) * d_mat_internal_cols +  d_mat_col_start +        result_col * d_mat_col_inc                        ] * csr_elements[item_id]; \n");
00143             else
00144               source.append("      sum += d_mat[ (d_mat_row_start + csr_cols[item_id] * d_mat_row_inc)                       + (d_mat_col_start +        result_col * d_mat_col_inc) * d_mat_internal_rows ] * csr_elements[item_id]; \n");
00145             source.append("    } \n");
00146 
00147             if (C_row_major)
00148               source.append("      result[ (result_row_start + row_id * result_row_inc) * result_internal_cols + result_col_start + result_col * result_col_inc ] = sum; \n");
00149             else
00150               source.append("      result[ (result_row_start + row_id * result_row_inc)                        + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = sum; \n");
00151             source.append("   } \n");
00152             source.append("  } \n");
00153             source.append("} \n");
00154           }
00155         }
00156 
00157         template <typename StringType>
00158         void generate_hyb_matrix_dense_matrix_multiplication(StringType & source, std::string const & numeric_string)
00159         {
00160           detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, false, false, false);
00161           detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, false, false,  true);
00162           detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, false,  true, false);
00163           detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, false,  true,  true);
00164 
00165           detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, true, false, false);
00166           detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, true, false,  true);
00167           detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, true,  true, false);
00168           detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, true,  true,  true);
00169         }
00170 
00172 
00173         // main kernel class
00175         template <typename NumericT>
00176         struct hyb_matrix
00177         {
00178           static std::string program_name()
00179           {
00180             return viennacl::ocl::type_to_string<NumericT>::apply() + "_hyb_matrix";
00181           }
00182 
00183           static void init(viennacl::ocl::context & ctx)
00184           {
00185             viennacl::ocl::DOUBLE_PRECISION_CHECKER<NumericT>::apply(ctx);
00186             std::string numeric_string = viennacl::ocl::type_to_string<NumericT>::apply();
00187 
00188             static std::map<cl_context, bool> init_done;
00189             if (!init_done[ctx.handle().get()])
00190             {
00191               std::string source;
00192               source.reserve(1024);
00193 
00194               viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source);
00195 
00196               generate_hyb_vec_mul(source, numeric_string);
00197               generate_hyb_matrix_dense_matrix_multiplication(source, numeric_string);
00198 
00199               std::string prog_name = program_name();
00200               #ifdef VIENNACL_BUILD_INFO
00201               std::cout << "Creating program " << prog_name << std::endl;
00202               #endif
00203               ctx.add_program(source, prog_name);
00204               init_done[ctx.handle().get()] = true;
00205             } //if
00206           } //init
00207         };
00208 
00209       }  // namespace kernels
00210     }  // namespace opencl
00211   }  // namespace linalg
00212 }  // namespace viennacl
00213 #endif
00214