libflame  revision_anchor
Functions
FLA_Gemm_external_gpu.c File Reference

(r)

Functions

FLA_Error FLA_Gemm_external_gpu (FLA_Trans transa, FLA_Trans transb, FLA_Obj alpha, FLA_Obj A, void *A_gpu, FLA_Obj B, void *B_gpu, FLA_Obj beta, FLA_Obj C, void *C_gpu)
 

Function Documentation

◆ FLA_Gemm_external_gpu()

/*
 * FLA_Gemm_external_gpu
 *
 * Dispatches a general matrix-matrix multiply to the cuBLAS gemm routine
 * that matches the datatype of A (cublasSgemm, cublasDgemm, cublasCgemm or
 * cublasZgemm).
 *
 * Parameters:
 *   transa, transb  - transposition flags for A and B; translated to their
 *                     character form by FLA_Param_map_flame_to_netlib_trans().
 *   alpha, beta     - scalar objects; their values are read through the
 *                     FLA_*_PTR macros and passed to cuBLAS by value.
 *   A, B, C         - matrix objects; only their dimensions and the datatype
 *                     of A are queried here.
 *   A_gpu, B_gpu,
 *   C_gpu           - buffers handed to cuBLAS in place of the object
 *                     buffers (presumably device memory mirroring A, B and
 *                     C -- confirm against the caller).
 *
 * Returns:
 *   FLA_SUCCESS, except on the zero-dimension path for A or B, where the
 *   status returned by FLA_Scal_external_gpu() is passed through.
 *
 * Notes:
 *   - If C has a zero dimension nothing is done.
 *   - If A or B has a zero dimension, C is only scaled by beta.
 *   - Each buffer is passed with a leading dimension equal to the length
 *     (row count) of the corresponding object.
 *     NOTE(review): this assumes the GPU buffers are packed column-major
 *     copies without padding -- confirm.
 *   - cuBLAS status is not queried after the gemm call.
 */
FLA_Error FLA_Gemm_external_gpu( FLA_Trans transa, FLA_Trans transb, FLA_Obj alpha, FLA_Obj A, void *A_gpu, FLA_Obj B, void *B_gpu, FLA_Obj beta, FLA_Obj C, void *C_gpu )
{
  FLA_Datatype datatype;
  int          k_AB;
  int          m_A, n_A;
  int          m_C, n_C;
  int          ldim_A;
  int          ldim_B;
  int          ldim_C;
  char         blas_transa;
  char         blas_transb;

  if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
  {
    FLA_Gemm_check( transa, transb, alpha, A, B, beta, C );
  }

  /* Nothing to compute when the output is empty. */
  if ( FLA_Obj_has_zero_dim( C ) )
  {
    return FLA_SUCCESS;
  }

  /* With an empty A or B only the scaling of C by beta is performed.
     Propagate the helper's status instead of discarding it. */
  if ( FLA_Obj_has_zero_dim( A ) || FLA_Obj_has_zero_dim( B ) )
  {
    return FLA_Scal_external_gpu( beta, C, C_gpu );
  }

  datatype = FLA_Obj_datatype( A );

  m_A      = FLA_Obj_length( A );
  n_A      = FLA_Obj_width( A );
  ldim_A   = FLA_Obj_length( A );

  ldim_B   = FLA_Obj_length( B );

  m_C      = FLA_Obj_length( C );
  n_C      = FLA_Obj_width( C );
  ldim_C   = FLA_Obj_length( C );

  /* The inner dimension is the width of A when A is not transposed and
     the length of A otherwise. */
  if ( transa == FLA_NO_TRANSPOSE || transa == FLA_CONJ_NO_TRANSPOSE )
  {
    k_AB = n_A;
  }
  else
  {
    k_AB = m_A;
  }

  FLA_Param_map_flame_to_netlib_trans( transa, &blas_transa );
  FLA_Param_map_flame_to_netlib_trans( transb, &blas_transb );

  switch ( datatype )
  {
    case FLA_FLOAT:
    {
      float *buff_alpha = ( float * ) FLA_FLOAT_PTR( alpha );
      float *buff_beta  = ( float * ) FLA_FLOAT_PTR( beta );

      cublasSgemm( blas_transa,
                   blas_transb,
                   m_C,
                   n_C,
                   k_AB,
                   *buff_alpha,
                   ( float * ) A_gpu, ldim_A,
                   ( float * ) B_gpu, ldim_B,
                   *buff_beta,
                   ( float * ) C_gpu, ldim_C );

      break;
    }

    case FLA_DOUBLE:
    {
      double *buff_alpha = ( double * ) FLA_DOUBLE_PTR( alpha );
      double *buff_beta  = ( double * ) FLA_DOUBLE_PTR( beta );

      cublasDgemm( blas_transa,
                   blas_transb,
                   m_C,
                   n_C,
                   k_AB,
                   *buff_alpha,
                   ( double * ) A_gpu, ldim_A,
                   ( double * ) B_gpu, ldim_B,
                   *buff_beta,
                   ( double * ) C_gpu, ldim_C );

      break;
    }

    case FLA_COMPLEX:
    {
      cuComplex *buff_alpha = ( cuComplex * ) FLA_COMPLEX_PTR( alpha );
      cuComplex *buff_beta  = ( cuComplex * ) FLA_COMPLEX_PTR( beta );

      cublasCgemm( blas_transa,
                   blas_transb,
                   m_C,
                   n_C,
                   k_AB,
                   *buff_alpha,
                   ( cuComplex * ) A_gpu, ldim_A,
                   ( cuComplex * ) B_gpu, ldim_B,
                   *buff_beta,
                   ( cuComplex * ) C_gpu, ldim_C );

      break;
    }

    case FLA_DOUBLE_COMPLEX:
    {
      cuDoubleComplex *buff_alpha = ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( alpha );
      cuDoubleComplex *buff_beta  = ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( beta );

      cublasZgemm( blas_transa,
                   blas_transb,
                   m_C,
                   n_C,
                   k_AB,
                   *buff_alpha,
                   ( cuDoubleComplex * ) A_gpu, ldim_A,
                   ( cuDoubleComplex * ) B_gpu, ldim_B,
                   *buff_beta,
                   ( cuDoubleComplex * ) C_gpu, ldim_C );

      break;
    }

    default:
      /* Any other datatype performs no computation and still reports
         success, as before.
         NOTE(review): FLA_Gemm_check() presumably rejects such datatypes
         when full error checking is enabled -- confirm. */
      break;
  }

  return FLA_SUCCESS;
}
FLA_Error FLA_Gemm_check(FLA_Trans transa, FLA_Trans transb, FLA_Obj alpha, FLA_Obj A, FLA_Obj B, FLA_Obj beta, FLA_Obj C)
Definition: FLA_Gemm_check.c:13
FLA_Error FLA_Scal_external_gpu(FLA_Obj alpha, FLA_Obj A, void *A_gpu)
Definition: FLA_Scal_external_gpu.c:17
dim_t FLA_Obj_width(FLA_Obj obj)
Definition: FLA_Query.c:123
FLA_Bool FLA_Obj_has_zero_dim(FLA_Obj A)
Definition: FLA_Query.c:400
dim_t FLA_Obj_length(FLA_Obj obj)
Definition: FLA_Query.c:116
unsigned int FLA_Check_error_level(void)
Definition: FLA_Check.c:18
void FLA_Param_map_flame_to_netlib_trans(FLA_Trans trans, void *blas_trans)
Definition: FLA_Param.c:15
FLA_Datatype FLA_Obj_datatype(FLA_Obj obj)
Definition: FLA_Query.c:13
int FLA_Datatype
Definition: FLA_type_defs.h:49

References FLA_Check_error_level(), FLA_Gemm_check(), FLA_Obj_datatype(), FLA_Obj_has_zero_dim(), FLA_Obj_length(), FLA_Obj_width(), FLA_Param_map_flame_to_netlib_trans(), and FLA_Scal_external_gpu().

Referenced by FLASH_Queue_exec_task_gpu().