Skip to content

Commit 0beb8db

Browse files
authored
ggml-vulkan: add SGN operator, auto-generate Vulkan.csv and ops.md (#20219)
1 parent b2f460b commit 0beb8db

File tree

5 files changed

+51
-5
lines changed

5 files changed

+51
-5
lines changed

docs/ops.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Legend:
4747
| FILL ||||||||||||
4848
| FLASH_ATTN_EXT || 🟡 || 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |||
4949
| FLOOR |||| 🟡 ||| 🟡 | 🟡 ||||
50+
| GATED_DELTA_NET ||||||||||||
5051
| GATED_LINEAR_ATTN ||||||||||||
5152
| GEGLU ||||| 🟡 ||| 🟡 ||||
5253
| GEGLU_ERF ||||| 🟡 ||| 🟡 ||||
@@ -92,7 +93,7 @@ Legend:
9293
| SCALE || 🟡 ||||||||||
9394
| SET ||||||| 🟡 |||||
9495
| SET_ROWS || 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |||
95-
| SGN |||| 🟡 | 🟡 ||| ||||
96+
| SGN |||| 🟡 | 🟡 ||| 🟡 ||||
9697
| SIGMOID |||| 🟡 | 🟡 | 🟡 || 🟡 ||||
9798
| SILU |||| 🟡 | 🟡 | 🟡 || 🟡 ||||
9899
| SILU_BACK ||||||||||||

docs/ops/Vulkan.csv

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"backend_name","op_name","op_params","test_mode","supported","error_message","backend_reg_name"
22
"Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
33
"Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
4-
"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
5-
"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
4+
"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
5+
"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
66
"Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
77
"Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
88
"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
@@ -85,8 +85,8 @@
8585
"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
8686
"Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
8787
"Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
88-
"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
89-
"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
88+
"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
89+
"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
9090
"Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
9191
"Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
9292
"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
@@ -13591,3 +13591,16 @@
1359113591
"Vulkan0","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","Vulkan"
1359213592
"Vulkan0","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
1359313593
"Vulkan0","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
13594+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
13595+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
13596+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
13597+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
13598+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","0","no","Vulkan"
13599+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","0","no","Vulkan"
13600+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","0","no","Vulkan"
13601+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
13602+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
13603+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
13604+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
13605+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","0","no","Vulkan"
13606+
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","0","no","Vulkan"

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,7 @@ struct vk_device_struct {
763763
vk_pipeline pipeline_ceil[2];
764764
vk_pipeline pipeline_floor[2];
765765
vk_pipeline pipeline_trunc[2];
766+
vk_pipeline pipeline_sgn[2];
766767

767768
vk_pipeline pipeline_add1_f16_f16;
768769
vk_pipeline pipeline_add1_f16_f32;
@@ -4393,6 +4394,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
43934394
CREATE_UNARY(ceil)
43944395
CREATE_UNARY(floor)
43954396
CREATE_UNARY(trunc)
4397+
CREATE_UNARY(sgn)
43964398
#undef CREATE_UNARY
43974399

43984400
#define CREATE_UNARY_RTE(name) \
@@ -9281,6 +9283,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
92819283
return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
92829284
case GGML_UNARY_OP_TRUNC:
92839285
return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
9286+
case GGML_UNARY_OP_SGN:
9287+
return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16];
92849288
default:
92859289
break;
92869290
}
@@ -12875,6 +12879,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1287512879
case GGML_UNARY_OP_CEIL:
1287612880
case GGML_UNARY_OP_FLOOR:
1287712881
case GGML_UNARY_OP_TRUNC:
12882+
case GGML_UNARY_OP_SGN:
1287812883
ggml_vk_unary(ctx, compute_ctx, src0, node);
1287912884
break;
1288012885
case GGML_UNARY_OP_XIELU:
@@ -15004,6 +15009,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1500415009
case GGML_UNARY_OP_CEIL:
1500515010
case GGML_UNARY_OP_FLOOR:
1500615011
case GGML_UNARY_OP_TRUNC:
15012+
case GGML_UNARY_OP_SGN:
1500715013
return ggml_is_contiguous(op->src[0]) &&
1500815014
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1500915015
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -16170,6 +16176,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1617016176
case GGML_UNARY_OP_TRUNC:
1617116177
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
1617216178
break;
16179+
case GGML_UNARY_OP_SGN:
16180+
tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]);
16181+
break;
1617316182
default:
1617416183
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
1617516184
GGML_ABORT("fatal error");
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#version 450
2+
3+
#include "generic_head.glsl"
4+
#include "types.glsl"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() {
14+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
15+
16+
if (i >= p.KX) {
17+
return;
18+
}
19+
20+
data_d[i] = D_TYPE(sign(float(data_a[i])));
21+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,8 @@ void process_shaders() {
871871
string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
872872
string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
873873
string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
874+
string_to_spv("sgn_f16", "sgn.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
875+
string_to_spv("sgn_f32", "sgn.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
874876

875877
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
876878
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

0 commit comments

Comments
 (0)