libflame  revision_anchor
Functions
FLA_Syrk_external_gpu.c File Reference

(r)

Functions

FLA_Error FLA_Syrk_external_gpu (FLA_Uplo uplo, FLA_Trans trans, FLA_Obj alpha, FLA_Obj A, void *A_gpu, FLA_Obj beta, FLA_Obj C, void *C_gpu)
 

Function Documentation

◆ FLA_Syrk_external_gpu()

/*
 * FLA_Syrk_external_gpu
 *
 * Performs the BLAS ?syrk operation on GPU-resident buffers through the
 * (legacy) cuBLAS interface:
 *
 *   C := alpha * A * A^T + beta * C    when trans == FLA_NO_TRANSPOSE
 *   C := alpha * A^T * A + beta * C    otherwise
 *
 * updating only the triangle of C selected by uplo.
 *
 * Parameters:
 *   uplo   - which triangle of C is referenced/updated.
 *   trans  - whether A is used as-is or transposed.
 *   alpha  - host-side scalar object; its value is read on the host and
 *            passed to cuBLAS by value.
 *   A      - object supplying the datatype and dimensions of A (metadata
 *            only; the entries are read from A_gpu).
 *   A_gpu  - device buffer holding the entries of A.
 *   beta   - host-side scalar object; read on the host, passed by value.
 *   C      - object supplying the dimensions of C (metadata only).
 *   C_gpu  - device buffer holding the entries of C; updated in place.
 *
 * The leading dimension passed for each buffer is the length of the
 * corresponding object. That is, the device copies are assumed to be stored
 * column-major without padding. NOTE(review): confirm against the code that
 * allocates and transfers these buffers.
 *
 * Parameter checking via FLA_Syrk_check runs only under
 * FLA_FULL_ERROR_CHECKING. cuBLAS status is not inspected here.
 *
 * Returns FLA_SUCCESS.
 */
FLA_Error FLA_Syrk_external_gpu( FLA_Uplo uplo, FLA_Trans trans, FLA_Obj alpha, FLA_Obj A, void *A_gpu, FLA_Obj beta, FLA_Obj C, void *C_gpu )
{
  FLA_Datatype datatype;
  int          k_A;
  int          m_A, n_A;
  int          m_C;
  int          ldim_A;
  int          ldim_C;
  char         blas_uplo;
  char         blas_trans;

  if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
    FLA_Syrk_check( uplo, trans, alpha, A, beta, C );

  /* Nothing to update when C is empty. */
  if ( FLA_Obj_has_zero_dim( C ) ) return FLA_SUCCESS;

  datatype = FLA_Obj_datatype( A );

  m_A = FLA_Obj_length( A );
  n_A = FLA_Obj_width( A );
  m_C = FLA_Obj_length( C );

  /* BLAS/cuBLAS require a leading dimension of at least max(1, rows) even
     when the stored matrix has zero rows. Example: A is 0 x n with
     trans != FLA_NO_TRANSPOSE, so k == 0 while C is non-empty. Passing
     lda == 0 in that case makes cuBLAS reject the call, and the mandatory
     C := beta * C scaling would silently not happen. Clamp to 1. */
  ldim_A = ( m_A > 1 ? m_A : 1 );
  ldim_C = ( m_C > 1 ? m_C : 1 );

  /* k is the inner (summed-over) dimension of the rank-k product. */
  if ( trans == FLA_NO_TRANSPOSE )
  {
    k_A = n_A;
  }
  else
  {
    k_A = m_A;
  }

  FLA_Param_map_flame_to_netlib_uplo( uplo, &blas_uplo );
  FLA_Param_map_flame_to_netlib_trans( trans, &blas_trans );

  switch ( datatype )
  {
    case FLA_FLOAT:
    {
      float *buff_alpha = ( float * ) FLA_FLOAT_PTR( alpha );
      float *buff_beta  = ( float * ) FLA_FLOAT_PTR( beta );

      cublasSsyrk( blas_uplo,
                   blas_trans,
                   m_C,
                   k_A,
                   *buff_alpha,
                   ( float * ) A_gpu, ldim_A,
                   *buff_beta,
                   ( float * ) C_gpu, ldim_C );

      break;
    }

    case FLA_DOUBLE:
    {
      double *buff_alpha = ( double * ) FLA_DOUBLE_PTR( alpha );
      double *buff_beta  = ( double * ) FLA_DOUBLE_PTR( beta );

      cublasDsyrk( blas_uplo,
                   blas_trans,
                   m_C,
                   k_A,
                   *buff_alpha,
                   ( double * ) A_gpu, ldim_A,
                   *buff_beta,
                   ( double * ) C_gpu, ldim_C );

      break;
    }

    case FLA_COMPLEX:
    {
      cuComplex *buff_alpha = ( cuComplex * ) FLA_COMPLEX_PTR( alpha );
      cuComplex *buff_beta  = ( cuComplex * ) FLA_COMPLEX_PTR( beta );

      cublasCsyrk( blas_uplo,
                   blas_trans,
                   m_C,
                   k_A,
                   *buff_alpha,
                   ( cuComplex * ) A_gpu, ldim_A,
                   *buff_beta,
                   ( cuComplex * ) C_gpu, ldim_C );

      break;
    }

    case FLA_DOUBLE_COMPLEX:
    {
      cuDoubleComplex *buff_alpha = ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( alpha );
      cuDoubleComplex *buff_beta  = ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( beta );

      cublasZsyrk( blas_uplo,
                   blas_trans,
                   m_C,
                   k_A,
                   *buff_alpha,
                   ( cuDoubleComplex * ) A_gpu, ldim_A,
                   *buff_beta,
                   ( cuDoubleComplex * ) C_gpu, ldim_C );

      break;
    }

    default:
      /* Other datatypes have no cuBLAS ?syrk mapping here and are left
         untouched. NOTE(review): presumably such inputs are rejected by
         FLA_Syrk_check when full error checking is enabled — confirm. */
      break;
  }

  return FLA_SUCCESS;
}
FLA_Error FLA_Syrk_check(FLA_Uplo uplo, FLA_Trans trans, FLA_Obj alpha, FLA_Obj A, FLA_Obj beta, FLA_Obj C)
Definition: FLA_Syrk_check.c:13
dim_t FLA_Obj_width(FLA_Obj obj)
Definition: FLA_Query.c:123
FLA_Bool FLA_Obj_has_zero_dim(FLA_Obj A)
Definition: FLA_Query.c:400
dim_t FLA_Obj_length(FLA_Obj obj)
Definition: FLA_Query.c:116
unsigned int FLA_Check_error_level(void)
Definition: FLA_Check.c:18
void FLA_Param_map_flame_to_netlib_trans(FLA_Trans trans, void *blas_trans)
Definition: FLA_Param.c:15
void FLA_Param_map_flame_to_netlib_uplo(FLA_Uplo uplo, void *blas_uplo)
Definition: FLA_Param.c:47
FLA_Datatype FLA_Obj_datatype(FLA_Obj obj)
Definition: FLA_Query.c:13
int FLA_Datatype
Definition: FLA_type_defs.h:49

References FLA_Check_error_level(), FLA_Obj_datatype(), FLA_Obj_has_zero_dim(), FLA_Obj_length(), FLA_Obj_width(), FLA_Param_map_flame_to_netlib_trans(), FLA_Param_map_flame_to_netlib_uplo(), and FLA_Syrk_check().

Referenced by FLASH_Queue_exec_task_gpu().