Skip to content

Commit

Permalink
Fix minor issues with previous code.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 21, 2024
1 parent 5a56df2 commit d78aa08
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,14 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
mps_k = [graph transposeTensor:mps_k dimension:-3 withDimension:-2 name:nil];
mps_v = [graph transposeTensor:mps_v dimension:-3 withDimension:-2 name:nil];
MPSGraphTensor* mps_o = [graph scaledDotProductAttentionWithQueryTensor:mps_q keyTensor:mps_k valueTensor:mps_v scale:scale name:nil];
mps_o = [graph transposeTensor:mps_o dimension:-3 withDimension:-2 name:nil];
[resultTensors addObject:mps_o];
[graph dump];
});
MPSGraphTensorData* data_q = ccv_nnc_mps_graph_tensor_data(q, qdim, qstride);
MPSGraphTensorData* data_k = ccv_nnc_mps_graph_tensor_data(k, kdim, kstride);
MPSGraphTensorData* data_v = ccv_nnc_mps_graph_tensor_data(v, vdim, vstride);
MPSGraphTensorData* data[] = {data_q, data_k, data_v};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]]], &o, (int*[]){ o->info.dim }, (int*[]){ o->stride }, 1, 0);
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]], data[indices[2]]], &o, (int*[]){ o->info.dim }, (int*[]){ o->stride }, 1, 0);
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
return CCV_NNC_EXEC_SUCCESS;
/*
Expand Down
3 changes: 2 additions & 1 deletion test/int/nnc/mpsblas.tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -1530,7 +1530,8 @@ TEST_CASE("scaled dot product attention with mps")
#define num_short_trials 2
#define num_trials (num_long_trials + num_short_trials)

for (int trial = 0; trial < num_trials; ++trial) {
ccv_nnc_enable_flag(CCV_NNC_DISABLE_METAL_FLASH_ATTENTION);
for (int trial = 0; trial < 1; ++trial) {
int B_candidates[num_trials] = { 32, 32, 3, 2, 1 };
int R_candidates[num_trials] = { 128, 128, 61, 6, 2 };
int C_candidates[num_trials] = { 128, 128, 49, 2, 1 };
Expand Down

0 comments on commit d78aa08

Please sign in to comment.