Commit

Implement kv-cache
Fixes #3.
certik committed Mar 14, 2023
1 parent fa53655 commit 664276e
Showing 2 changed files with 92 additions and 45 deletions.
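
For context, the key idea of the change below: the model now distinguishes n_seq, the total number of positions covered by the key/value cache, from n_seq_x, the number of positions actually fed through the network in a given call. On the first call the whole prompt is processed and every key/value column is written into kv_cache; on each later call only the newest token is processed (n_seq_x = 1) and its key/value pair is stored at column n_seq. A minimal, self-contained sketch of that bookkeeping (made-up sizes; nothing here is code from the repository):

program kv_cache_sketch
  implicit none
  integer, parameter :: sp = kind(1.0)
  integer, parameter :: n_embd = 8, n_layer = 2, n_prompt = 4, n_new = 3
  real(sp) :: kv_cache(n_embd, n_prompt + n_new, 2, n_layer)
  integer :: step, n_seq, n_seq_x
  kv_cache = 0
  do step = 1, n_new
     if (step == 1) then
        n_seq_x = n_prompt          ! first call: feed the whole prompt
     else
        n_seq_x = 1                 ! later calls: feed only the newest token
     end if
     n_seq = n_prompt + step - 1    ! positions the cache covers after this call
     ! the model would run here on n_seq_x columns; when n_seq_x == 1 the new
     ! keys and values are written into kv_cache(:, n_seq, 1:2, :)
     print *, "step", step, "n_seq =", n_seq, "n_seq_x =", n_seq_x
  end do
end program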
133 changes: 89 additions & 44 deletions gpt2.f90
@@ -61,46 +61,68 @@ function ffn(x, fc_w, fc_b, proj_w, proj_b) result(y)
y = linear(gelu(linear(x, fc_w, fc_b)), proj_w, proj_b)
end function

function attention(q, k, v, mask) result(y)
real(sp), intent(in) :: q(:,:), k(:,:), v(:,:), mask(:,:)
real(sp) :: y(size(v,1),size(q,2))
real(sp) :: tmp(size(k,2),size(q,2))
function attention(n_embd_head,n_seq,n_seq_x, q, k, v, mask) result(y)
integer, intent(in) :: n_embd_head, n_seq, n_seq_x
real(sp), intent(in) :: q(n_embd_head,n_seq_x), k(n_embd_head,n_seq), v(n_embd_head,n_seq), mask(n_seq,n_seq_x)
real(sp) :: y(n_embd_head,n_seq_x)
real(sp) :: tmp(n_seq,n_seq_x)
!tmp = matmul(transpose(k), q)
!call matmul_2d(transpose(k), q, tmp)
call matmul_2d_t(k, q, tmp)
call matmul_2d(v, softmax(tmp / sqrt(real(size(q,1),sp)) + mask), y)
call matmul_2d(v, softmax(tmp / sqrt(real(n_embd_head,sp)) + mask), y)
end function
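
For reference, the routine above is the usual scaled dot-product attention for one head; with caching, q carries only the new query column(s) while k and v span all n_seq cached positions, so y comes out with n_seq_x columns. A standalone sketch of that shape bookkeeping using intrinsics only (random data, made-up sizes, and an inline column-wise softmax instead of the module's own):

program attention_shapes
  implicit none
  integer, parameter :: sp = kind(1.0)
  integer, parameter :: n_embd_head = 4, n_seq = 3, n_seq_x = 1
  real(sp) :: q(n_embd_head,n_seq_x), k(n_embd_head,n_seq), v(n_embd_head,n_seq)
  real(sp) :: scores(n_seq,n_seq_x), y(n_embd_head,n_seq_x)
  integer :: j
  call random_number(q); call random_number(k); call random_number(v)
  scores = matmul(transpose(k), q) / sqrt(real(n_embd_head, sp))
  do j = 1, n_seq_x                                   ! column-wise softmax
     scores(:,j) = exp(scores(:,j) - maxval(scores(:,j)))
     scores(:,j) = scores(:,j) / sum(scores(:,j))
  end do
  y = matmul(v, scores)                               ! one output column per query
  print *, "shape of y:", shape(y)                    ! prints 4 1
end program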

function mha(n_seq, n_embd, x, attn_w, attn_b, proj_w, proj_b, n_head) &
function mha(n_seq, n_seq_x, n_embd, x, attn_w, attn_b, proj_w, proj_b, n_head, &
use_kv_cache, kv_cache) &
result(y)
integer, intent(in) :: n_seq, n_embd
real(sp), intent(in) :: x(n_embd,n_seq), &
integer, intent(in) :: n_seq, n_seq_x, n_embd
real(sp), intent(in) :: x(n_embd,n_seq_x), &
attn_w(3*n_embd,n_embd), attn_b(3*n_embd), &
proj_w(n_embd,n_embd), proj_b(n_embd)
real(sp), intent(inout) :: kv_cache(n_embd,n_seq,2)
integer, intent(in) :: n_head
real(sp) :: y(n_embd,n_seq)
real(sp) :: causal_mask(n_seq,n_seq)
real(sp) :: x2(3*n_embd,n_seq)
logical, intent(in) :: use_kv_cache
real(sp) :: y(n_embd,n_seq_x)
real(sp) :: causal_mask(n_seq,n_seq_x)
real(sp) :: x2(3*n_embd,n_seq_x)
integer :: i, j
! Mask
do j = 1, n_seq
do i = 1, n_seq
if (i > j) then
causal_mask(i,j) = -1e10_sp
else
causal_mask(i,j) = 0
end if
end do
end do
if (use_kv_cache) then
causal_mask = 0
else
do j = 1, n_seq
do i = 1, n_seq
if (i > j) then
causal_mask(i,j) = -1e10_sp
else
causal_mask(i,j) = 0
end if
end do
end do
end if
x2 = linear(x, attn_w, attn_b)
associate ( &
q => x2((1-1)*n_embd+1:1*n_embd,:), &
k => x2((2-1)*n_embd+1:2*n_embd,:), &
v => x2((3-1)*n_embd+1:3*n_embd,:) &
)
if (use_kv_cache) then
kv_cache(:,n_seq,1) = k(:,1)
kv_cache(:,n_seq,2) = v(:,1)
else
kv_cache(:,:,1) = k
kv_cache(:,:,2) = v
end if
end associate
associate ( &
q => x2((1-1)*n_embd+1:1*n_embd,:), &
k => kv_cache(:,:,1), &
v => kv_cache(:,:,2) &
)
! Perform attention over each head
do i = 1, n_head
y((i-1)*n_embd/n_head+1:i*n_embd/n_head,:) = attention( &
n_embd/n_head, n_seq, n_seq_x, &
q((i-1)*n_embd/n_head+1:i*n_embd/n_head,:), &
k((i-1)*n_embd/n_head+1:i*n_embd/n_head,:), &
v((i-1)*n_embd/n_head+1:i*n_embd/n_head,:), &
@@ -112,31 +134,32 @@ function mha(n_seq, n_embd, x, attn_w, attn_b, proj_w, proj_b, n_head) &
end function
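
Two details in mha above are easy to miss. First, the initial associate block only writes keys and values: the single fresh column goes into kv_cache(:,n_seq,:) when the cache is in use, otherwise the whole k and v are stored; the second associate block then reads k and v back from the cache, so attention always sees the full history. Second, the cached branch sets causal_mask = 0 because the lone query is the token at position n_seq, which may attend to every earlier position anyway. A tiny self-contained check of that mask claim (sketch only, not code from this file):

program mask_check
  implicit none
  integer, parameter :: n_seq = 5
  real :: full_mask(n_seq, n_seq)
  integer :: i, j
  do j = 1, n_seq
     do i = 1, n_seq
        if (i > j) then
           full_mask(i,j) = -1e10
        else
           full_mask(i,j) = 0
        end if
     end do
  end do
  ! the column for the newest query (position n_seq) is all zeros,
  ! so the cached single-query path can skip the triangular mask entirely
  print *, all(full_mask(:, n_seq) == 0)   ! prints T
end program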


function transformer_block(x, mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
function transformer_block(n_seq, n_seq_x, n_embd, x, mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, ln1_g, ln1_b, ln2_g, ln2_b, &
n_head) result(y)
real(sp), intent(in) :: x(:,:), &
n_head, use_kv_cache, kv_cache) result(y)
real(sp), intent(in) :: x(n_embd,n_seq_x), &
mlp_fc_w(:,:), mlp_fc_b(:), &
mlp_proj_w(:,:), mlp_proj_b(:), &
attn_w(:,:), attn_b(:), attn_proj_w(:,:), attn_proj_b(:), &
ln1_g(:), ln1_b(:), ln2_g(:), ln2_b(:)
integer, intent(in) :: n_head
real(sp) :: y(size(x,1),size(x,2))
integer :: n_seq, n_embd
n_embd = size(x,1)
n_seq = size(x,2)
y = x + mha(n_seq, n_embd, layer_norm(x, ln1_g, ln1_b, 1e-5_sp), &
attn_w, attn_b, attn_proj_w, attn_proj_b, n_head)
integer, intent(in) :: n_seq, n_seq_x, n_embd
real(sp) :: y(n_embd,n_seq_x)
logical, intent(in) :: use_kv_cache
real(sp), intent(inout) :: kv_cache(n_embd,n_seq,2)
y = x + mha(n_seq, n_seq_x, n_embd, layer_norm(x, ln1_g, ln1_b, 1e-5_sp), &
attn_w, attn_b, attn_proj_w, attn_proj_b, n_head, use_kv_cache, kv_cache)
y = y + ffn(layer_norm(y, ln2_g, ln2_b, 1e-5_sp), &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b)
end function

function gpt2(n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, input, &
function gpt2(n_vocab, n_ctx, n_seq, n_seq_x, n_embd, n_layer, n_head, input, &
wte, wpe, &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b) result(y)
integer, intent(in) :: n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, &
use_kv_cache, kv_cache) result(y)
integer, intent(in) :: n_vocab, n_ctx, n_seq, n_seq_x, n_embd, n_layer, n_head
integer, intent(in) :: input(n_seq)
real(sp), intent(in) :: wte(n_embd,n_vocab), wpe(n_embd,n_ctx), &
mlp_fc_w(4*n_embd,n_embd,n_layer), mlp_fc_b(4*n_embd,n_layer), &
@@ -146,19 +169,26 @@ function gpt2(n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, input, &
ln1_b(n_embd,n_layer), ln1_g(n_embd,n_layer), &
ln2_b(n_embd,n_layer), ln2_g(n_embd,n_layer), &
lnf_b(n_embd), lnf_g(n_embd)
real(sp) :: y(n_vocab,n_seq)
real(sp) :: x(n_embd,n_seq)
logical, intent(in) :: use_kv_cache
real(sp), intent(inout) :: kv_cache(n_embd,n_seq,2,n_layer)
real(sp) :: y(n_vocab,n_seq_x)
real(sp) :: x(n_embd,n_seq_x)
integer :: i
do i = 1, n_seq
x(:,i) = wte(:,input(i)+1) + wpe(:,i)
end do
if (use_kv_cache) then
i = n_seq
x(:,1) = wte(:,input(i)+1) + wpe(:,i)
else
do i = 1, n_seq
x(:,i) = wte(:,input(i)+1) + wpe(:,i)
end do
end if
do i = 1, n_layer
x = transformer_block(x, &
x = transformer_block(n_seq, n_seq_x, n_embd, x, &
mlp_fc_w(:,:,i), mlp_fc_b(:,i), &
mlp_proj_w(:,:,i), mlp_proj_b(:,i), &
attn_w(:,:,i), attn_b(:,i), attn_proj_w(:,:,i), attn_proj_b(:,i), &
ln1_g(:,i), ln1_b(:,i), ln2_g(:,i), ln2_b(:,i), &
n_head)
n_head, use_kv_cache, kv_cache(:,:,:,i))
end do
x = layer_norm(x, lnf_g, lnf_b, 1e-5)
!y = matmul(transpose(wte), x)
@@ -170,7 +200,7 @@ function generate(n_tokens_to_generate, &
wte, wpe, &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b) result(output)
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache) result(output)
integer, intent(in) :: n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, &
n_tokens_to_generate
integer, intent(in) :: input(n_seq)
@@ -182,22 +212,37 @@ function generate(n_tokens_to_generate, &
ln1_b(n_embd,n_layer), ln1_g(n_embd,n_layer), &
ln2_b(n_embd,n_layer), ln2_g(n_embd,n_layer), &
lnf_b(n_embd), lnf_g(n_embd)
logical, intent(in) :: use_cache
integer :: output(n_tokens_to_generate)
real(sp), allocatable :: logits(:,:)
integer :: i
integer :: n_seq2, n_seq_x
integer :: next_id
integer, allocatable :: input2(:)
logical :: use_kv_cache
real(sp) :: kv_cache(n_embd,n_seq+n_tokens_to_generate,2,n_layer)
allocate(input2(size(input)))
input2 = input
do i = 1, n_tokens_to_generate
allocate(logits(n_vocab, size(input2)))
logits = gpt2(n_vocab, n_ctx, size(input2), n_embd, n_layer, n_head, &
if (use_cache) then
use_kv_cache = (i > 1) ! Use cache for subsequent tokens
else
use_kv_cache = .false.
end if
n_seq2 = size(input2)
if (use_kv_cache) then
n_seq_x = 1
else
n_seq_x = n_seq2
end if
allocate(logits(n_vocab, n_seq_x))
logits = gpt2(n_vocab, n_ctx, n_seq2, n_seq_x, n_embd, n_layer, n_head, &
input2, &
wte, wpe, &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b)
next_id = maxloc(logits(:,size(logits,2)), dim=1)-1
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_kv_cache, kv_cache(:,:n_seq2,:,:))
next_id = maxloc(logits(:,n_seq_x), dim=1)-1
print *, i, next_id
input2 = [input2, next_id]
deallocate(logits)
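In generate above, the cache is allocated once for the longest sequence that can ever occur, n_seq + n_tokens_to_generate positions, and each call receives only the leading n_seq2 columns. The memory cost is modest; a back-of-the-envelope estimate, assuming single precision and the GPT-2 small dimensions (n_embd = 768, n_layer = 12) with 1024 cached positions:

program cache_size
  use iso_fortran_env, only: int64
  implicit none
  integer, parameter :: n_embd = 768, n_layer = 12, n_positions = 1024
  integer(int64) :: bytes
  bytes = int(n_embd, int64) * n_positions * 2 * n_layer * 4   ! keys + values, 4 bytes per real(sp)
  print *, bytes, "bytes (~", real(bytes) / 2.0**20, "MiB)"    ! about 72 MiB
end program
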
4 changes: 3 additions & 1 deletion main.f90
@@ -22,6 +22,7 @@ program gpt2
character(:), allocatable :: output_txt
real(dp) :: t1, t2, t1o, t2o
integer :: u
logical :: use_cache

! Load the model
print "(a)", "Loading the model..."
@@ -86,13 +87,14 @@ program gpt2
print "(a)", "Running model..."
call cpu_time(t1)
t1o = omp_get_wtime()
use_cache = .true.
output = generate(n_tokens_to_generate, n_vocab, n_ctx, size(input), n_embd, &
n_layer, n_head, &
input, &
wte, wpe, &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b)
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache)
t2o = omp_get_wtime()
call cpu_time(t2)
print "(a,f8.3,a,f4.2,a)", " done. Time:", t2o-t1o, "s (", (t2-t1)/(t2o-t1o), "x)"
