ViennaCL - The Vienna Computing Library
1.5.0
|
00001 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_SOLVE_HPP 00002 #define VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_SOLVE_HPP 00003 00004 #include "viennacl/tools/tools.hpp" 00005 #include "viennacl/ocl/kernel.hpp" 00006 #include "viennacl/ocl/platform.hpp" 00007 #include "viennacl/ocl/utils.hpp" 00008 00009 #include "viennacl/linalg/opencl/kernels/matrix.hpp" 00010 00013 namespace viennacl 00014 { 00015 namespace linalg 00016 { 00017 namespace opencl 00018 { 00019 namespace kernels 00020 { 00021 00022 template <typename StringType> 00023 void generate_matrix_solve_blas3(StringType & source, std::string const & numeric_string, 00024 bool row_major_A, bool row_major_B, 00025 bool transpose_A, bool transpose_B, 00026 bool upper_solve, bool unit_diagonal) 00027 { 00028 //start OpenCL code: 00029 source.append("__kernel void "); 00030 if (transpose_A) 00031 source.append("trans_"); 00032 if (unit_diagonal) 00033 source.append("unit_"); 00034 if (upper_solve) 00035 source.append("upper_"); 00036 else 00037 source.append("lower_"); 00038 if (transpose_B) 00039 source.append("trans_"); 00040 source.append("solve"); 00041 00042 source.append("( \n"); 00043 source.append(" __global const "); source.append(numeric_string); source.append(" * A, \n"); 00044 source.append(" unsigned int A_start1, unsigned int A_start2, \n"); 00045 source.append(" unsigned int A_inc1, unsigned int A_inc2, \n"); 00046 source.append(" unsigned int A_size1, unsigned int A_size2, \n"); 00047 source.append(" unsigned int A_internal_size1, unsigned int A_internal_size2, \n"); 00048 source.append(" __global "); source.append(numeric_string); source.append(" * B, \n"); 00049 source.append(" unsigned int B_start1, unsigned int B_start2, \n"); 00050 source.append(" unsigned int B_inc1, unsigned int B_inc2, \n"); 00051 source.append(" unsigned int B_size1, unsigned int B_size2, \n"); 00052 source.append(" unsigned int B_internal_size1, unsigned int B_internal_size2) { \n"); 00053 source.append(" "); source.append(numeric_string); source.append(" temp; \n"); 00054 if (upper_solve) 00055 { 00056 //Note: A is square, thus A_rows == A_cols and no dispatch for transposedness needed 00057 source.append(" for (unsigned int row_cnt = 0; row_cnt < A_size1; ++row_cnt) \n"); 00058 source.append(" { \n"); 00059 source.append(" unsigned int row = A_size1 - 1 - row_cnt; \n"); 00060 } 00061 else //lower triangular solve 00062 { 00063 source.append(" for (unsigned int row = 0; row < A_size1; ++row) \n"); 00064 source.append(" { \n"); 00065 } 00066 00067 if (!unit_diagonal) 00068 { 00069 source.append(" barrier(CLK_GLOBAL_MEM_FENCE); \n"); 00070 source.append(" if (get_local_id(0) == 0) \n"); 00071 //Note: A is square, thus A_internal_rows == A_internal_cols and no dispatch for transposedness needed 00072 if (row_major_B && transpose_B) 00073 source.append(" B[(get_group_id(0) * B_inc1 + B_start1) * B_internal_size2 + (row * B_inc2 + B_start2)] /= "); 00074 else if (row_major_B && !transpose_B) 00075 source.append(" B[(row * B_inc1 + B_start1) * B_internal_size2 + (get_group_id(0) * B_inc2 + B_start2)] /= "); 00076 else if (!row_major_B && transpose_B) 00077 source.append(" B[(get_group_id(0) * B_inc1 + B_start1) + (row * B_inc2 + B_start2) * B_internal_size1] /= "); 00078 else if (!row_major_B && !transpose_B) 00079 source.append(" B[(row * B_inc1 + B_start1) + (get_group_id(0) * B_inc2 + B_start2) * B_internal_size1] /= "); 00080 00081 if (row_major_A) 00082 source.append("A[(row * A_inc1 + A_start1) * A_internal_size2 + (row * A_inc2 + A_start2)]; \n"); 00083 else 00084 source.append("A[(row * A_inc1 + A_start1) + (row * A_inc2 + A_start2)*A_internal_size1]; \n"); 00085 } 00086 00087 source.append(" barrier(CLK_GLOBAL_MEM_FENCE); \n"); 00088 00089 if (row_major_B && transpose_B) 00090 source.append(" temp = B[(get_group_id(0) * B_inc1 + B_start1) * B_internal_size2 + (row * B_inc2 + B_start2)]; \n"); 00091 else if (row_major_B && !transpose_B) 00092 source.append(" temp = B[(row * B_inc1 + B_start1) * B_internal_size2 + (get_group_id(0) * B_inc2 + B_start2)]; \n"); 00093 else if (!row_major_B && transpose_B) 00094 source.append(" temp = B[(get_group_id(0) * B_inc1 + B_start1) + (row * B_inc2 + B_start2) * B_internal_size1]; \n"); 00095 else if (!row_major_B && !transpose_B) 00096 source.append(" temp = B[(row * B_inc1 + B_start1) + (get_group_id(0) * B_inc2 + B_start2) * B_internal_size1]; \n"); 00097 00098 source.append(" //eliminate column of op(A) with index 'row' in parallel: \n"); 00099 if (upper_solve) 00100 source.append(" for (unsigned int elim = get_local_id(0); elim < row; elim += get_local_size(0)) \n"); 00101 else 00102 source.append(" for (unsigned int elim = row + get_local_id(0) + 1; elim < A_size1; elim += get_local_size(0)) \n"); 00103 00104 if (row_major_B && transpose_B) 00105 source.append(" B[(get_group_id(0) * B_inc1 + B_start1) * B_internal_size2 + (elim * B_inc2 + B_start2)] -= temp * "); 00106 else if (row_major_B && !transpose_B) 00107 source.append(" B[(elim * B_inc1 + B_start1) * B_internal_size2 + (get_group_id(0) * B_inc2 + B_start2)] -= temp * "); 00108 else if (!row_major_B && transpose_B) 00109 source.append(" B[(get_group_id(0) * B_inc1 + B_start1) + (elim * B_inc2 + B_start2) * B_internal_size1] -= temp * "); 00110 else if (!row_major_B && !transpose_B) 00111 source.append(" B[(elim * B_inc1 + B_start1) + (get_group_id(0) * B_inc2 + B_start2) * B_internal_size1] -= temp * "); 00112 00113 if (row_major_A && transpose_A) 00114 source.append("A[(row * A_inc1 + A_start1) * A_internal_size2 + (elim * A_inc2 + A_start2)]; \n"); 00115 else if (row_major_A && !transpose_A) 00116 source.append("A[(elim * A_inc1 + A_start1) * A_internal_size2 + (row * A_inc2 + A_start2)]; \n"); 00117 else if (!row_major_A && transpose_A) 00118 source.append("A[(row * A_inc1 + A_start1) + (elim * A_inc2 + A_start2) * A_internal_size1]; \n"); 00119 else if (!row_major_A && !transpose_A) 00120 source.append("A[(elim * A_inc1 + A_start1) + (row * A_inc2 + A_start2) * A_internal_size1]; \n"); 00121 00122 source.append(" } \n"); 00123 source.append("} \n"); 00124 } 00125 00126 00127 // main kernel class 00133 template <class NumericT, typename F1, typename F2> 00134 struct matrix_solve 00135 { 00136 static std::string program_name() 00137 { 00138 return viennacl::ocl::type_to_string<NumericT>::apply() + "_matrix_solve_" + detail::type_to_string(F1()) + detail::type_to_string(F2()); 00139 } 00140 00141 static void init(viennacl::ocl::context & ctx) 00142 { 00143 viennacl::ocl::DOUBLE_PRECISION_CHECKER<NumericT>::apply(ctx); 00144 std::string numeric_string = viennacl::ocl::type_to_string<NumericT>::apply(); 00145 bool matrix_row_major = viennacl::is_row_major<F1>::value; 00146 bool rhs_row_major = viennacl::is_row_major<F2>::value; 00147 00148 00149 static std::map<cl_context, bool> init_done; 00150 if (!init_done[ctx.handle().get()]) 00151 { 00152 std::string source; 00153 source.reserve(8192); 00154 00155 viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source); 00156 00157 // only generate for floating points (forces error for integers) 00158 if (numeric_string == "float" || numeric_string == "double") 00159 { 00160 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00161 false, false, false, false); 00162 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00163 false, false, false, true); 00164 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00165 false, false, true, false); 00166 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00167 false, false, true, true); 00168 00169 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00170 false, true, false, false); 00171 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00172 false, true, false, true); 00173 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00174 false, true, true, false); 00175 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00176 false, true, true, true); 00177 00178 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00179 true, false, false, false); 00180 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00181 true, false, false, true); 00182 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00183 true, false, true, false); 00184 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00185 true, false, true, true); 00186 00187 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00188 true, true, false, false); 00189 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00190 true, true, false, true); 00191 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00192 true, true, true, false); 00193 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major, 00194 true, true, true, true); 00195 } 00196 00197 std::string prog_name = program_name(); 00198 #ifdef VIENNACL_BUILD_INFO 00199 std::cout << "Creating program " << prog_name << std::endl; 00200 #endif 00201 ctx.add_program(source, prog_name); 00202 init_done[ctx.handle().get()] = true; 00203 } //if 00204 } //init 00205 }; 00206 00207 } // namespace kernels 00208 } // namespace opencl 00209 } // namespace linalg 00210 } // namespace viennacl 00211 #endif 00212