Skip to content

Commit c125883

Browse files
authored
ggml webgpu: ops support for qwen3.5 (SET, TRI_SOLVE, SSM_CONV, GATED_DELTA_NET) + GET_ROWS optimization (#20687)
* Implement l2_norm, set, tri * Add DIAG/SOLVE_TRI * Add SSM_CONV * Better get_rows and gated_delta_net to support qwen3.5 * Clean up, update ops.md * Fix binding_index type for wasm * Fix read write annotations * cleanups
1 parent 922b90e commit c125883

File tree

10 files changed

+2872
-6904
lines changed

10 files changed

+2872
-6904
lines changed

docs/ops.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Legend:
4747
| FILL ||||||||||||
4848
| FLASH_ATTN_EXT || 🟡 || 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |||
4949
| FLOOR |||| 🟡 ||| 🟡 | 🟡 ||||
50-
| GATED_DELTA_NET ||||||||| |||
50+
| GATED_DELTA_NET ||||||||| |||
5151
| GATED_LINEAR_ATTN ||||||||||||
5252
| GEGLU ||||| 🟡 ||| 🟡 ||||
5353
| GEGLU_ERF ||||| 🟡 ||| 🟡 ||||
@@ -91,7 +91,7 @@ Legend:
9191
| RWKV_WKV6 ||||||||||||
9292
| RWKV_WKV7 ||||||||||||
9393
| SCALE || 🟡 ||||||||||
94-
| SET ||||||| 🟡 || |||
94+
| SET ||||||| 🟡 || |||
9595
| SET_ROWS || 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |||
9696
| SGN |||| 🟡 | 🟡 ||| 🟡 ||||
9797
| SIGMOID |||| 🟡 | 🟡 | 🟡 || 🟡 ||||
@@ -101,10 +101,10 @@ Legend:
101101
| SOFTPLUS |||| 🟡 | 🟡 ||| 🟡 ||||
102102
| SOFT_MAX || 🟡 ||||||||||
103103
| SOFT_MAX_BACK ||| 🟡 | 🟡 ||| 🟡 |||||
104-
| SOLVE_TRI |||| 🟡 ||||| |||
104+
| SOLVE_TRI |||| 🟡 ||||| |||
105105
| SQR ||||| 🟡 || 🟡 | 🟡 ||||
106106
| SQRT ||||| 🟡 || 🟡 | 🟡 ||||
107-
| SSM_CONV ||||||||| |||
107+
| SSM_CONV ||||||||| |||
108108
| SSM_SCAN |||||||| 🟡 ||||
109109
| STEP |||| 🟡 | 🟡 ||| 🟡 ||||
110110
| SUB ||||| 🟡 |||||||

docs/ops/WebGPU.csv

Lines changed: 1834 additions & 6880 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 269 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions {
9595
uint32_t wg_size = 0;
9696
};
9797

98+
struct ggml_webgpu_ssm_conv_shader_decisions {
99+
uint32_t block_size;
100+
uint32_t tokens_per_wg;
101+
};
102+
98103
/** Argsort **/
99104

100105
struct ggml_webgpu_argsort_shader_lib_context {
@@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions {
131136
uint32_t wg_size;
132137
};
133138

139+
/** Set **/
140+
141+
struct ggml_webgpu_set_pipeline_key {
142+
ggml_type type;
143+
bool inplace;
144+
145+
bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
146+
return type == other.type && inplace == other.inplace;
147+
}
148+
};
149+
150+
struct ggml_webgpu_set_pipeline_key_hash {
151+
size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
152+
size_t seed = 0;
153+
ggml_webgpu_hash_combine(seed, key.type);
154+
ggml_webgpu_hash_combine(seed, key.inplace);
155+
return seed;
156+
}
157+
};
158+
134159
/** Get Rows **/
135160

136161
struct ggml_webgpu_get_rows_pipeline_key {
@@ -186,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash {
186211
}
187212
};
188213

214+
/** Solve Tri **/
215+
struct ggml_webgpu_solve_tri_pipeline_key {
216+
int type;
217+
int n;
218+
int k;
219+
220+
bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
221+
return type == other.type && n == other.n && k == other.k;
222+
}
223+
};
224+
225+
struct ggml_webgpu_solve_tri_pipeline_key_hash {
226+
size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
227+
size_t seed = 0;
228+
ggml_webgpu_hash_combine(seed, key.type);
229+
ggml_webgpu_hash_combine(seed, key.n);
230+
ggml_webgpu_hash_combine(seed, key.k);
231+
return seed;
232+
}
233+
};
234+
235+
/** SSM Conv **/
236+
struct ggml_webgpu_ssm_conv_pipeline_key {
237+
int type;
238+
int vectorized;
239+
240+
bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
241+
return type == other.type && vectorized == other.vectorized;
242+
}
243+
};
244+
245+
/** Gated Delta Net **/
246+
struct ggml_webgpu_gated_delta_net_pipeline_key {
247+
int type;
248+
int s_v;
249+
int kda;
250+
251+
bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
252+
return type == other.type && s_v == other.s_v && kda == other.kda;
253+
}
254+
};
255+
256+
struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
257+
size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
258+
size_t seed = 0;
259+
ggml_webgpu_hash_combine(seed, key.type);
260+
ggml_webgpu_hash_combine(seed, key.s_v);
261+
ggml_webgpu_hash_combine(seed, key.kda);
262+
return seed;
263+
}
264+
};
265+
266+
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
267+
size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
268+
size_t seed = 0;
269+
ggml_webgpu_hash_combine(seed, key.type);
270+
ggml_webgpu_hash_combine(seed, key.vectorized);
271+
return seed;
272+
}
273+
};
274+
189275
/** Scale **/
190276

191277
struct ggml_webgpu_scale_pipeline_key {
@@ -466,14 +552,22 @@ class ggml_webgpu_shader_lib {
466552
unary_pipelines; // type/op/inplace
467553
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
468554
scale_pipelines; // inplace
555+
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
556+
solve_tri_pipelines; // type
557+
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
558+
ssm_conv_pipelines; // type/vectorized
559+
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
560+
webgpu_pipeline,
561+
ggml_webgpu_gated_delta_net_pipeline_key_hash>
562+
gated_delta_net_pipelines; // type/S_v/kda
469563
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
470-
pad_pipelines; // circular/non-circular
564+
pad_pipelines; // circular/non-circular
471565
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
472-
binary_pipelines; // type/op/inplace/overlap
566+
binary_pipelines; // type/op/inplace/overlap
473567
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
474-
concat_pipelines; // type
568+
concat_pipelines; // type
475569
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
476-
repeat_pipelines; // type
570+
repeat_pipelines; // type
477571
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
478572
flash_attn_pipelines;
479573
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
@@ -487,6 +581,7 @@ class ggml_webgpu_shader_lib {
487581

488582
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
489583
set_rows_pipelines;
584+
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
490585

491586
public:
492587
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@@ -519,11 +614,11 @@ class ggml_webgpu_shader_lib {
519614

520615
switch (key.op) {
521616
case GGML_OP_RMS_NORM:
522-
defines.push_back("OP_RMS_NORM");
617+
defines.push_back("RMS_NORM");
523618
variant = "rms_norm";
524619
break;
525620
case GGML_OP_L2_NORM:
526-
defines.push_back("OP_L2_NORM");
621+
defines.push_back("L2_NORM");
527622
variant = "l2_norm";
528623
break;
529624
default:
@@ -535,8 +630,9 @@ class ggml_webgpu_shader_lib {
535630
variant += "_inplace";
536631
}
537632

538-
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
539-
633+
const uint32_t row_norm_wg_size = 128u;
634+
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
635+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
540636
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
541637
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
542638
return row_norm_pipelines[key];
@@ -609,6 +705,46 @@ class ggml_webgpu_shader_lib {
609705
return set_rows_pipelines[key];
610706
}
611707

708+
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
709+
ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
710+
711+
auto it = set_pipelines.find(key);
712+
if (it != set_pipelines.end()) {
713+
return it->second;
714+
}
715+
716+
std::vector<std::string> defines;
717+
std::string variant = "set";
718+
719+
switch (key.type) {
720+
case GGML_TYPE_F32:
721+
defines.push_back("TYPE_F32");
722+
variant += "_f32";
723+
break;
724+
case GGML_TYPE_I32:
725+
defines.push_back("TYPE_I32");
726+
variant += "_i32";
727+
break;
728+
default:
729+
GGML_ABORT("Unsupported type for set shader");
730+
}
731+
732+
if (key.inplace) {
733+
defines.push_back("INPLACE");
734+
variant += "_inplace";
735+
}
736+
737+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
738+
739+
auto processed = preprocessor.preprocess(wgsl_set, defines);
740+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
741+
decisions->wg_size = context.max_wg_size;
742+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
743+
pipeline.context = decisions;
744+
set_pipelines[key] = pipeline;
745+
return set_pipelines[key];
746+
}
747+
612748
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
613749
auto it = cumsum_pipelines.find(1);
614750
if (it != cumsum_pipelines.end()) {
@@ -695,6 +831,7 @@ class ggml_webgpu_shader_lib {
695831

696832
switch (key.src_type) {
697833
case GGML_TYPE_F32:
834+
defines.push_back("FLOAT_PARALLEL");
698835
if (key.vectorized) {
699836
defines.push_back("F32_VEC");
700837
defines.push_back("SRC_TYPE=vec4<f32>");
@@ -709,13 +846,15 @@ class ggml_webgpu_shader_lib {
709846
variant += "_f32";
710847
break;
711848
case GGML_TYPE_F16:
849+
defines.push_back("FLOAT_PARALLEL");
712850
defines.push_back("F16");
713851
defines.push_back("SRC_TYPE=f16");
714852
defines.push_back("DST_TYPE=f32");
715853
defines.push_back("BLOCK_SIZE=1u");
716854
variant += "_f16";
717855
break;
718856
case GGML_TYPE_I32:
857+
defines.push_back("FLOAT_PARALLEL");
719858
defines.push_back("I32");
720859
defines.push_back("SRC_TYPE=i32");
721860
defines.push_back("DST_TYPE=i32");
@@ -794,6 +933,128 @@ class ggml_webgpu_shader_lib {
794933
return scale_pipelines[key];
795934
}
796935

936+
webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
937+
ggml_webgpu_solve_tri_pipeline_key key = {
938+
.type = context.dst->type,
939+
.n = (int) context.src0->ne[0],
940+
.k = (int) context.src1->ne[0],
941+
};
942+
943+
auto it = solve_tri_pipelines.find(key);
944+
if (it != solve_tri_pipelines.end()) {
945+
return it->second;
946+
}
947+
948+
std::vector<std::string> defines;
949+
std::string variant = "solve_tri";
950+
951+
switch (key.type) {
952+
case GGML_TYPE_F32:
953+
variant += "_f32";
954+
break;
955+
default:
956+
GGML_ABORT("Unsupported type for solve_tri shader");
957+
}
958+
959+
const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size);
960+
const uint32_t k_tile = wg_size;
961+
const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
962+
const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);
963+
964+
defines.push_back(std::string("N=") + std::to_string(key.n));
965+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
966+
defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
967+
defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));
968+
969+
auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);
970+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
971+
decisions->wg_size = wg_size;
972+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
973+
pipeline.context = decisions;
974+
solve_tri_pipelines[key] = pipeline;
975+
return solve_tri_pipelines[key];
976+
}
977+
978+
webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
979+
ggml_webgpu_ssm_conv_pipeline_key key = {
980+
.type = context.dst->type,
981+
.vectorized = context.src1->ne[0] == 4,
982+
};
983+
984+
auto it = ssm_conv_pipelines.find(key);
985+
if (it != ssm_conv_pipelines.end()) {
986+
return it->second;
987+
}
988+
989+
std::vector<std::string> defines;
990+
std::string variant = "ssm_conv";
991+
992+
switch (key.type) {
993+
case GGML_TYPE_F32:
994+
variant += "_f32";
995+
break;
996+
default:
997+
GGML_ABORT("Unsupported type for ssm_conv shader");
998+
}
999+
1000+
if (key.vectorized) {
1001+
defines.push_back("VECTORIZED");
1002+
variant += "_vec4";
1003+
}
1004+
1005+
constexpr uint32_t block_size = 32u;
1006+
constexpr uint32_t tokens_per_wg = 8u;
1007+
1008+
defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
1009+
defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");
1010+
1011+
auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);
1012+
auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
1013+
decisions->block_size = block_size;
1014+
decisions->tokens_per_wg = tokens_per_wg;
1015+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1016+
pipeline.context = decisions;
1017+
ssm_conv_pipelines[key] = pipeline;
1018+
return ssm_conv_pipelines[key];
1019+
}
1020+
1021+
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
1022+
ggml_webgpu_gated_delta_net_pipeline_key key = {
1023+
.type = context.dst->type,
1024+
.s_v = (int) context.src2->ne[0],
1025+
.kda = context.src3->ne[0] == context.src2->ne[0],
1026+
};
1027+
1028+
auto it = gated_delta_net_pipelines.find(key);
1029+
if (it != gated_delta_net_pipelines.end()) {
1030+
return it->second;
1031+
}
1032+
1033+
std::vector<std::string> defines;
1034+
std::string variant = "gated_delta_net";
1035+
1036+
switch (key.type) {
1037+
case GGML_TYPE_F32:
1038+
variant += "_f32";
1039+
break;
1040+
default:
1041+
GGML_ABORT("Unsupported type for gated_delta_net shader");
1042+
}
1043+
1044+
if (key.kda) {
1045+
defines.push_back("KDA");
1046+
variant += "_kda";
1047+
}
1048+
1049+
defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
1050+
defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");
1051+
1052+
auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
1053+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1054+
gated_delta_net_pipelines[key] = pipeline;
1055+
return gated_delta_net_pipelines[key];
1056+
}
1057+
7971058
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
7981059
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
7991060

0 commit comments

Comments
 (0)