libflame  revision_anchor
Functions
FLA_Gemv_external_gpu.c File Reference

(r)

Functions

FLA_Error FLA_Gemv_external_gpu (FLA_Trans transa, FLA_Obj alpha, FLA_Obj A, void *A_gpu, FLA_Obj x, void *x_gpu, FLA_Obj beta, FLA_Obj y, void *y_gpu)
 

Function Documentation

◆ FLA_Gemv_external_gpu()

/*
 * FLA_Gemv_external_gpu
 *
 * Computes y := beta * y + alpha * op( A ) * x with the legacy CUBLAS
 * gemv routines (character transpose flag, scalars passed by value).
 * op( A ) is selected by transa.
 *
 * The FLA_Obj arguments supply metadata (datatype and dimensions) and the
 * host-side values of the scalars alpha and beta. The matrix and vector
 * elements themselves are taken from the buffers A_gpu, x_gpu and y_gpu,
 * which are handed directly to CUBLAS.
 *
 * Parameters:
 *   transa - FLAME transpose parameter for A; converted to the netlib
 *            character form expected by CUBLAS.
 *   alpha  - scalar object; its value is read on the host and passed by value.
 *   A      - object describing A (datatype, length, width).
 *   A_gpu  - buffer holding the elements of A. The length of A is passed as
 *            the leading dimension, so A_gpu is assumed to be stored
 *            column-major with no padding between columns.
 *   x      - object describing x; x_gpu is accessed with unit stride.
 *   x_gpu  - buffer holding the elements of x.
 *   beta   - scalar object; its value is read on the host and passed by value.
 *   y      - object describing y; y_gpu is accessed with unit stride.
 *   y_gpu  - buffer holding the elements of y; overwritten with the result.
 *
 * Returns FLA_SUCCESS. If A has a zero dimension the function returns
 * immediately without touching y_gpu.
 *
 * NOTE(review): the legacy CUBLAS calls return void; any CUBLAS error is
 * not checked here (it would have to be queried separately by the caller).
 */
FLA_Error FLA_Gemv_external_gpu( FLA_Trans transa,
                                 FLA_Obj alpha,
                                 FLA_Obj A, void* A_gpu,
                                 FLA_Obj x, void* x_gpu,
                                 FLA_Obj beta,
                                 FLA_Obj y, void* y_gpu )
{
  FLA_Datatype datatype;
  /* Legacy CUBLAS takes int dimensions, hence int rather than dim_t. */
  int          m_A, n_A;
  int          ldim_A;
  int          inc_x;
  int          inc_y;
  char         blas_transa;

  if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
    FLA_Gemv_check( transa, alpha, A, x, beta, y );

  if ( FLA_Obj_has_zero_dim( A ) ) return FLA_SUCCESS;

  datatype = FLA_Obj_datatype( A );

  m_A    = FLA_Obj_length( A );
  n_A    = FLA_Obj_width( A );
  /* The buffer behind A_gpu is treated as densely packed: its leading
     dimension is the number of rows of A, not the host object's ldim. */
  ldim_A = FLA_Obj_length( A );

  /* x_gpu and y_gpu are treated as contiguous vectors. */
  inc_x  = 1;
  inc_y  = 1;

  FLA_Param_map_flame_to_netlib_trans( transa, &blas_transa );

  switch( datatype ){

  case FLA_FLOAT:
  {
    float *buff_alpha = ( float * ) FLA_FLOAT_PTR( alpha );
    float *buff_beta  = ( float * ) FLA_FLOAT_PTR( beta );

    /* Use CUBLAS like the other datatype branches; the host Fortran
       sgemv_ cannot operate on the GPU buffers (and expects every
       argument by reference). */
    cublasSgemv( blas_transa,
                 m_A,
                 n_A,
                 *buff_alpha,
                 ( float * ) A_gpu, ldim_A,
                 ( float * ) x_gpu, inc_x,
                 *buff_beta,
                 ( float * ) y_gpu, inc_y );

    break;
  }

  case FLA_DOUBLE:
  {
    double *buff_alpha = ( double * ) FLA_DOUBLE_PTR( alpha );
    double *buff_beta  = ( double * ) FLA_DOUBLE_PTR( beta );

    cublasDgemv( blas_transa,
                 m_A,
                 n_A,
                 *buff_alpha,
                 ( double * ) A_gpu, ldim_A,
                 ( double * ) x_gpu, inc_x,
                 *buff_beta,
                 ( double * ) y_gpu, inc_y );

    break;
  }

  case FLA_COMPLEX:
  {
    /* Assumes libflame's single-precision complex scalar is layout-
       compatible with cuComplex (real part followed by imaginary part). */
    cuComplex *buff_alpha = ( cuComplex * ) FLA_COMPLEX_PTR( alpha );
    cuComplex *buff_beta  = ( cuComplex * ) FLA_COMPLEX_PTR( beta );

    cublasCgemv( blas_transa,
                 m_A,
                 n_A,
                 *buff_alpha,
                 ( cuComplex * ) A_gpu, ldim_A,
                 ( cuComplex * ) x_gpu, inc_x,
                 *buff_beta,
                 ( cuComplex * ) y_gpu, inc_y );

    break;
  }

  case FLA_DOUBLE_COMPLEX:
  {
    /* Same layout-compatibility assumption as above, for cuDoubleComplex. */
    cuDoubleComplex *buff_alpha = ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( alpha );
    cuDoubleComplex *buff_beta  = ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( beta );

    cublasZgemv( blas_transa,
                 m_A,
                 n_A,
                 *buff_alpha,
                 ( cuDoubleComplex * ) A_gpu, ldim_A,
                 ( cuDoubleComplex * ) x_gpu, inc_x,
                 *buff_beta,
                 ( cuDoubleComplex * ) y_gpu, inc_y );

    break;
  }

  }

  return FLA_SUCCESS;
}
FLA_Error FLA_Gemv_check(FLA_Trans transa, FLA_Obj alpha, FLA_Obj A, FLA_Obj x, FLA_Obj beta, FLA_Obj y)
Definition: FLA_Gemv_check.c:13
dim_t FLA_Obj_width(FLA_Obj obj)
Definition: FLA_Query.c:123
FLA_Bool FLA_Obj_has_zero_dim(FLA_Obj A)
Definition: FLA_Query.c:400
dim_t FLA_Obj_length(FLA_Obj obj)
Definition: FLA_Query.c:116
unsigned int FLA_Check_error_level(void)
Definition: FLA_Check.c:18
void FLA_Param_map_flame_to_netlib_trans(FLA_Trans trans, void *blas_trans)
Definition: FLA_Param.c:15
FLA_Datatype FLA_Obj_datatype(FLA_Obj obj)
Definition: FLA_Query.c:13
int FLA_Datatype
Definition: FLA_type_defs.h:49

References FLA_Check_error_level(), FLA_Gemv_check(), FLA_Obj_datatype(), FLA_Obj_has_zero_dim(), FLA_Obj_length(), FLA_Obj_width(), and FLA_Param_map_flame_to_netlib_trans().

Referenced by FLASH_Queue_exec_task_gpu().