From 25a0b9080d8666b5c477bb9563368cdffa44c4ab Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 24 Oct 2023 13:34:54 +0300
Subject: [PATCH] cuda : try cublasGemmStridedBatchedEx

---
Review notes (ignored by git-am):
 - cuBLAS strides are element counts, not bytes: each stride is now the
   distance between two consecutive 2-D matrices, nb[2] / element size.
 - batchCount is ne12*ne13, because m/n/k describe a single matrix; the
   contiguity guard (nb[2]*ne[2] == nb[3]) is what allows one uniform stride
   to walk across dims 2 and 3.
 - strideB/strideC divide by sizeof(float) on the assumption that src1 and
   dst are F32 tensors (as the existing nb11/sizeof(float) already implies)
   - TODO confirm against the caller.
 - The debug printf was dropped; hunk and diffstat counts are adjusted. The
   post-image hash on the "index" line is carried over and must be
   regenerated when the commit is recreated.

 ggml-cuda.cu | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index e2dea9eab4be0d..148a657703f916 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -7134,8 +7134,20 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         }
     }
 #else
-    // use cublasGemmBatchedEx
-    {
+    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+        // use cublasGemmStridedBatchedEx: one GEMM per (i12, i13) pair, strides are in elements (not bytes)
+        CUBLAS_CHECK(
+        cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, (char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                            (char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                &beta_f16,  (char *)     dst_f16, CUDA_R_16F, ne01,                dst->nb[2]/sizeof(float), // strideC
+                ne12*ne13,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    } else {
+        // use cublasGemmBatchedEx
         const int ne23 = ne12*ne13;
 
         // TODO: avoid this alloc