libflame  revision_anchor
Functions
FLA_Syr2k_external_gpu.c File Reference

(r)

Functions

FLA_Error FLA_Syr2k_external_gpu (FLA_Uplo uplo, FLA_Trans trans, FLA_Obj alpha, FLA_Obj A, void *A_gpu, FLA_Obj B, void *B_gpu, FLA_Obj beta, FLA_Obj C, void *C_gpu)
 

Function Documentation

◆ FLA_Syr2k_external_gpu()

/*
 * FLA_Syr2k_external_gpu
 *
 * Dispatches a symmetric rank-2k update to the cuBLAS ?syr2k routine that
 * matches the datatype of A, operating on the device buffers A_gpu, B_gpu
 * and C_gpu.
 *
 * Parameters:
 *   uplo   - FLAME uplo selector; mapped to its netlib character form.
 *   trans  - FLAME transposition selector; mapped to its netlib character
 *            form and used to pick the inner dimension k.
 *   alpha  - scalar object; its buffer is read on the host and passed by value.
 *   A, B   - host-side objects, used here only for datatype and dimensions.
 *   A_gpu, B_gpu, C_gpu
 *          - buffers handed to cuBLAS in place of the objects' own storage.
 *   beta   - scalar object; its buffer is read on the host and passed by value.
 *   C      - host-side object, used for the zero-dimension test and for
 *            its length.
 *
 * Returns:
 *   FLA_SUCCESS. This is also returned when C has a zero dimension (nothing
 *   is dispatched) and when the datatype matches none of the four handled
 *   cases. The cuBLAS status is not inspected here.
 */
FLA_Error FLA_Syr2k_external_gpu( FLA_Uplo uplo, FLA_Trans trans, FLA_Obj alpha, FLA_Obj A, void *A_gpu, FLA_Obj B, void *B_gpu, FLA_Obj beta, FLA_Obj C, void *C_gpu )
{
  FLA_Datatype dt;
  int          rows_A, cols_A;
  int          order_C;
  int          inner_dim;
  int          ld_A, ld_B, ld_C;
  char         uplo_ch;
  char         trans_ch;

  /* Parameter validation runs only at the full error-checking level. */
  if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
    FLA_Syr2k_check( uplo, trans, alpha, A, B, beta, C );

  /* Nothing to update when C is empty. */
  if ( FLA_Obj_has_zero_dim( C ) )
    return FLA_SUCCESS;

  dt = FLA_Obj_datatype( A );

  rows_A = FLA_Obj_length( A );
  cols_A = FLA_Obj_width( A );

  /* Leading dimensions are taken from the object lengths (row counts).
     NOTE(review): this presumably relies on the GPU buffers being stored
     contiguously, column by column -- confirm against the GPU allocator. */
  ld_A = FLA_Obj_length( A );
  ld_B = FLA_Obj_length( B );

  order_C = FLA_Obj_length( C );
  ld_C    = FLA_Obj_length( C );

  /* k is the width of A when A is not transposed, otherwise its length. */
  inner_dim = ( trans == FLA_NO_TRANSPOSE ) ? cols_A : rows_A;

  FLA_Param_map_flame_to_netlib_uplo( uplo, &uplo_ch );
  FLA_Param_map_flame_to_netlib_trans( trans, &trans_ch );

  switch ( dt )
  {
    case FLA_FLOAT:
      cublasSsyr2k( uplo_ch, trans_ch,
                    order_C, inner_dim,
                    *( ( float * ) FLA_FLOAT_PTR( alpha ) ),
                    ( float * ) A_gpu, ld_A,
                    ( float * ) B_gpu, ld_B,
                    *( ( float * ) FLA_FLOAT_PTR( beta ) ),
                    ( float * ) C_gpu, ld_C );
      break;

    case FLA_DOUBLE:
      cublasDsyr2k( uplo_ch, trans_ch,
                    order_C, inner_dim,
                    *( ( double * ) FLA_DOUBLE_PTR( alpha ) ),
                    ( double * ) A_gpu, ld_A,
                    ( double * ) B_gpu, ld_B,
                    *( ( double * ) FLA_DOUBLE_PTR( beta ) ),
                    ( double * ) C_gpu, ld_C );
      break;

    case FLA_COMPLEX:
      cublasCsyr2k( uplo_ch, trans_ch,
                    order_C, inner_dim,
                    *( ( cuComplex * ) FLA_COMPLEX_PTR( alpha ) ),
                    ( cuComplex * ) A_gpu, ld_A,
                    ( cuComplex * ) B_gpu, ld_B,
                    *( ( cuComplex * ) FLA_COMPLEX_PTR( beta ) ),
                    ( cuComplex * ) C_gpu, ld_C );
      break;

    case FLA_DOUBLE_COMPLEX:
      cublasZsyr2k( uplo_ch, trans_ch,
                    order_C, inner_dim,
                    *( ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( alpha ) ),
                    ( cuDoubleComplex * ) A_gpu, ld_A,
                    ( cuDoubleComplex * ) B_gpu, ld_B,
                    *( ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( beta ) ),
                    ( cuDoubleComplex * ) C_gpu, ld_C );
      break;

    default:
      /* Any other datatype is left untouched, as before. */
      break;
  }

  return FLA_SUCCESS;
}
FLA_Error FLA_Syr2k_check(FLA_Uplo uplo, FLA_Trans trans, FLA_Obj alpha, FLA_Obj A, FLA_Obj B, FLA_Obj beta, FLA_Obj C)
Definition: FLA_Syr2k_check.c:13
dim_t FLA_Obj_width(FLA_Obj obj)
Definition: FLA_Query.c:123
FLA_Bool FLA_Obj_has_zero_dim(FLA_Obj A)
Definition: FLA_Query.c:400
dim_t FLA_Obj_length(FLA_Obj obj)
Definition: FLA_Query.c:116
unsigned int FLA_Check_error_level(void)
Definition: FLA_Check.c:18
void FLA_Param_map_flame_to_netlib_trans(FLA_Trans trans, void *blas_trans)
Definition: FLA_Param.c:15
void FLA_Param_map_flame_to_netlib_uplo(FLA_Uplo uplo, void *blas_uplo)
Definition: FLA_Param.c:47
FLA_Datatype FLA_Obj_datatype(FLA_Obj obj)
Definition: FLA_Query.c:13
int FLA_Datatype
Definition: FLA_type_defs.h:49

References FLA_Check_error_level(), FLA_Obj_datatype(), FLA_Obj_has_zero_dim(), FLA_Obj_length(), FLA_Obj_width(), FLA_Param_map_flame_to_netlib_trans(), FLA_Param_map_flame_to_netlib_uplo(), and FLA_Syr2k_check().

Referenced by FLASH_Queue_exec_task_gpu().