Skip to content

Commit

Permalink
cuda : try cublasGemmStridedBatchedEx
Browse files: browse the repository at this point in the history
  • Loading branch information
ggerganov committed Oct 24, 2023
1 parent d415669 commit 25a0b90
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions ggml-cuda.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -7134,8 +7134,21 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
}
}
#else
// use cublasGemmBatchedEx
{
if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), ne02*src0->nb[2], // strideA
(char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), ne12*src1->nb[2]/2, // strideB
&beta_f16, (char *) dst_f16, CUDA_R_16F, ne01, ne12* dst->nb[2]/2, // strideC
ne13,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
printf("cublasGemmStridedBatchedEx\n");
} else {
// use cublasGemmBatchedEx
const int ne23 = ne12*ne13;

// TODO: avoid this alloc
Expand Down

0 comments on commit 25a0b90

Please sign in to comment.