@@ -50,7 +50,7 @@ __global__ void repkv_forward_kernel1(floatX* replicated_qkv,
5050
5151__global__ void repkv_backward_kernel1 (floatX* dinp, const floatX* dout,
5252 int B, int N, int NH, int replicate_factor, int HD) {
53- // we have a single tensor dout of shapae of (B, N 3 * NH * HD)
53+ // we have a single tensor dout of shape (B, N, 3 * NH * HD)
5454 // we want to reduce sum (for K and V) into (B, N, (NH + 2*(NH/replicate_factor)) * HD)
5555 int idx = blockIdx .x * blockDim .x + threadIdx .x ;
5656 if (idx >= B * N * 3 * NH * HD) { return ;}
@@ -111,11 +111,11 @@ void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_
111111}
112112
// Host-side launcher for the repkv backward pass.
// Reduce-sums the gradient dout of shape (B, T, 3 * NH * HD) into
// dinp of shape (B, T, (NH + 2 * (NH / replicate_factor)) * HD),
// accumulating over the replicated K/V heads (replicate_factor = NH / NH_KV;
// assumes NH is divisible by NH_KV — TODO confirm callers guarantee this).
// The launch is enqueued on `stream`; only launch-configuration errors are
// checked here — execution errors surface at the next synchronizing call.
void repkv_backward(floatX* dinp, const floatX* dout,
    const int B, const int T, const int NH, const int NH_KV, const int d, cudaStream_t stream) {
    const int block_size = 128;
    // One thread per element of dout. Compute the total in 64-bit so
    // B * T * 3 * NH * d cannot overflow int for large configurations.
    size_t total_threads = (size_t)B * T * (3 * NH) * d;
    int num_blocks = (int)CEIL_DIV(total_threads, (size_t)block_size);
    int replicate_factor = NH / NH_KV;  // how many query heads share one K/V head
    repkv_backward_kernel1<<<num_blocks, block_size, 0, stream>>>(dinp, dout, B, T, NH, replicate_factor, d);
    cudaCheck(cudaGetLastError());  // catch bad launch config (grid/block limits)
}
0 commit comments