@@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions {
9595 uint32_t wg_size = 0 ;
9696};
9797
// Per-pipeline tuning decisions for the SSM_CONV shader, attached to the
// pipeline so the dispatcher can size its workgroup grid.
// Zero-initialized for consistency with ggml_webgpu_generic_shader_decisions.
struct ggml_webgpu_ssm_conv_shader_decisions {
    uint32_t block_size    = 0;  // threads per workgroup
    uint32_t tokens_per_wg = 0;  // tokens processed by one workgroup
};
102+
98103/* * Argsort **/
99104
100105struct ggml_webgpu_argsort_shader_lib_context {
@@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions {
131136 uint32_t wg_size;
132137};
133138
/** Set **/

// Cache key for SET-op pipelines: destination tensor type plus whether the
// operation writes its destination in place.
struct ggml_webgpu_set_pipeline_key {
    ggml_type type;
    bool      inplace;

    bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
        return type == other.type && inplace == other.inplace;
    }
};
149+
// Hash functor so ggml_webgpu_set_pipeline_key can index an unordered_map;
// combines every field that participates in operator==.
struct ggml_webgpu_set_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.inplace);
        return seed;
    }
};
158+
134159/* * Get Rows **/
135160
136161struct ggml_webgpu_get_rows_pipeline_key {
@@ -186,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash {
186211 }
187212};
188213
/** Solve Tri **/
// Cache key for SOLVE_TRI pipelines: dst type plus the problem dimensions
// baked into the shader (n = src0->ne[0], k = src1->ne[0]).
struct ggml_webgpu_solve_tri_pipeline_key {
    int type;
    int n;
    int k;

    bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
        return type == other.type && n == other.n && k == other.k;
    }
};
224+
// Hash functor for ggml_webgpu_solve_tri_pipeline_key; combines every field
// that participates in operator==.
struct ggml_webgpu_solve_tri_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.n);
        ggml_webgpu_hash_combine(seed, key.k);
        return seed;
    }
};
234+
/** SSM Conv **/
// Cache key for SSM_CONV pipelines: dst type plus a flag selecting the
// vectorized shader variant (stored as int, holds a 0/1 boolean).
struct ggml_webgpu_ssm_conv_pipeline_key {
    int type;
    int vectorized;

    bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
        return type == other.type && vectorized == other.vectorized;
    }
};
244+
/** Gated Delta Net **/
// Cache key for gated-delta-net pipelines: dst type, the state/value
// dimension S_v baked into the shader, and the KDA-variant flag
// (stored as int, holds a 0/1 boolean).
struct ggml_webgpu_gated_delta_net_pipeline_key {
    int type;
    int s_v;
    int kda;

    bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
        return type == other.type && s_v == other.s_v && kda == other.kda;
    }
};
255+
// Hash functor for ggml_webgpu_gated_delta_net_pipeline_key; combines every
// field that participates in operator==.
struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.s_v);
        ggml_webgpu_hash_combine(seed, key.kda);
        return seed;
    }
};
265+
// Hash functor for ggml_webgpu_ssm_conv_pipeline_key; combines every field
// that participates in operator==.
// NOTE(review): this sits under the "Gated Delta Net" section, away from its
// key struct in the "SSM Conv" section above — consider moving it up.
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.vectorized);
        return seed;
    }
};
274+
189275/* * Scale **/
190276
191277struct ggml_webgpu_scale_pipeline_key {
@@ -466,14 +552,22 @@ class ggml_webgpu_shader_lib {
466552 unary_pipelines; // type/op/inplace
467553 std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
468554 scale_pipelines; // inplace
555+ std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
556+ solve_tri_pipelines; // type
557+ std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
558+ ssm_conv_pipelines; // type/vectorized
559+ std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
560+ webgpu_pipeline,
561+ ggml_webgpu_gated_delta_net_pipeline_key_hash>
562+ gated_delta_net_pipelines; // type/S_v/kda
469563 std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
470- pad_pipelines; // circular/non-circular
564+ pad_pipelines; // circular/non-circular
471565 std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
472- binary_pipelines; // type/op/inplace/overlap
566+ binary_pipelines; // type/op/inplace/overlap
473567 std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
474- concat_pipelines; // type
568+ concat_pipelines; // type
475569 std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
476- repeat_pipelines; // type
570+ repeat_pipelines; // type
477571 std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
478572 flash_attn_pipelines;
479573 std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
@@ -487,6 +581,7 @@ class ggml_webgpu_shader_lib {
487581
488582 std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
489583 set_rows_pipelines;
584+ std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
490585
491586 public:
492587 ggml_webgpu_shader_lib (wgpu::Device device) { this ->device = device; }
@@ -519,11 +614,11 @@ class ggml_webgpu_shader_lib {
519614
520615 switch (key.op ) {
521616 case GGML_OP_RMS_NORM:
522- defines.push_back (" OP_RMS_NORM " );
617+ defines.push_back (" RMS_NORM " );
523618 variant = " rms_norm" ;
524619 break ;
525620 case GGML_OP_L2_NORM:
526- defines.push_back (" OP_L2_NORM " );
621+ defines.push_back (" L2_NORM " );
527622 variant = " l2_norm" ;
528623 break ;
529624 default :
@@ -535,8 +630,9 @@ class ggml_webgpu_shader_lib {
535630 variant += " _inplace" ;
536631 }
537632
538- defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (context.max_wg_size ));
539-
633+ const uint32_t row_norm_wg_size = 128u ;
634+ uint32_t wg_size = std::min (context.max_wg_size , row_norm_wg_size);
635+ defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (wg_size));
540636 auto processed = preprocessor.preprocess (wgsl_row_norm, defines);
541637 row_norm_pipelines[key] = ggml_webgpu_create_pipeline (device, processed, variant);
542638 return row_norm_pipelines[key];
@@ -609,6 +705,46 @@ class ggml_webgpu_shader_lib {
609705 return set_rows_pipelines[key];
610706 }
611707
708+ webgpu_pipeline get_set_pipeline (const ggml_webgpu_shader_lib_context & context) {
709+ ggml_webgpu_set_pipeline_key key = { .type = context.dst ->type , .inplace = context.inplace };
710+
711+ auto it = set_pipelines.find (key);
712+ if (it != set_pipelines.end ()) {
713+ return it->second ;
714+ }
715+
716+ std::vector<std::string> defines;
717+ std::string variant = " set" ;
718+
719+ switch (key.type ) {
720+ case GGML_TYPE_F32:
721+ defines.push_back (" TYPE_F32" );
722+ variant += " _f32" ;
723+ break ;
724+ case GGML_TYPE_I32:
725+ defines.push_back (" TYPE_I32" );
726+ variant += " _i32" ;
727+ break ;
728+ default :
729+ GGML_ABORT (" Unsupported type for set shader" );
730+ }
731+
732+ if (key.inplace ) {
733+ defines.push_back (" INPLACE" );
734+ variant += " _inplace" ;
735+ }
736+
737+ defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (context.max_wg_size ));
738+
739+ auto processed = preprocessor.preprocess (wgsl_set, defines);
740+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
741+ decisions->wg_size = context.max_wg_size ;
742+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline (device, processed, variant);
743+ pipeline.context = decisions;
744+ set_pipelines[key] = pipeline;
745+ return set_pipelines[key];
746+ }
747+
612748 webgpu_pipeline get_cumsum_pipeline (const ggml_webgpu_shader_lib_context & context) {
613749 auto it = cumsum_pipelines.find (1 );
614750 if (it != cumsum_pipelines.end ()) {
@@ -695,6 +831,7 @@ class ggml_webgpu_shader_lib {
695831
696832 switch (key.src_type ) {
697833 case GGML_TYPE_F32:
834+ defines.push_back (" FLOAT_PARALLEL" );
698835 if (key.vectorized ) {
699836 defines.push_back (" F32_VEC" );
700837 defines.push_back (" SRC_TYPE=vec4<f32>" );
@@ -709,13 +846,15 @@ class ggml_webgpu_shader_lib {
709846 variant += " _f32" ;
710847 break ;
711848 case GGML_TYPE_F16:
849+ defines.push_back (" FLOAT_PARALLEL" );
712850 defines.push_back (" F16" );
713851 defines.push_back (" SRC_TYPE=f16" );
714852 defines.push_back (" DST_TYPE=f32" );
715853 defines.push_back (" BLOCK_SIZE=1u" );
716854 variant += " _f16" ;
717855 break ;
718856 case GGML_TYPE_I32:
857+ defines.push_back (" FLOAT_PARALLEL" );
719858 defines.push_back (" I32" );
720859 defines.push_back (" SRC_TYPE=i32" );
721860 defines.push_back (" DST_TYPE=i32" );
@@ -794,6 +933,128 @@ class ggml_webgpu_shader_lib {
794933 return scale_pipelines[key];
795934 }
796935
936+ webgpu_pipeline get_solve_tri_pipeline (const ggml_webgpu_shader_lib_context & context) {
937+ ggml_webgpu_solve_tri_pipeline_key key = {
938+ .type = context.dst ->type ,
939+ .n = (int ) context.src0 ->ne [0 ],
940+ .k = (int ) context.src1 ->ne [0 ],
941+ };
942+
943+ auto it = solve_tri_pipelines.find (key);
944+ if (it != solve_tri_pipelines.end ()) {
945+ return it->second ;
946+ }
947+
948+ std::vector<std::string> defines;
949+ std::string variant = " solve_tri" ;
950+
951+ switch (key.type ) {
952+ case GGML_TYPE_F32:
953+ variant += " _f32" ;
954+ break ;
955+ default :
956+ GGML_ABORT (" Unsupported type for solve_tri shader" );
957+ }
958+
959+ const uint32_t wg_size = std::min ((uint32_t ) key.n , context.max_wg_size );
960+ const uint32_t k_tile = wg_size;
961+ const uint32_t bytes_per_row = ((uint32_t ) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
962+ const uint32_t batch_n = (uint32_t ) (context.wg_mem_limit_bytes / bytes_per_row);
963+
964+ defines.push_back (std::string (" N=" ) + std::to_string (key.n ));
965+ defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (wg_size));
966+ defines.push_back (std::string (" K_TILE=" ) + std::to_string (k_tile));
967+ defines.push_back (std::string (" BATCH_N=" ) + std::to_string (batch_n));
968+
969+ auto processed = preprocessor.preprocess (wgsl_solve_tri, defines);
970+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
971+ decisions->wg_size = wg_size;
972+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline (device, processed, variant);
973+ pipeline.context = decisions;
974+ solve_tri_pipelines[key] = pipeline;
975+ return solve_tri_pipelines[key];
976+ }
977+
978+ webgpu_pipeline get_ssm_conv_pipeline (const ggml_webgpu_shader_lib_context & context) {
979+ ggml_webgpu_ssm_conv_pipeline_key key = {
980+ .type = context.dst ->type ,
981+ .vectorized = context.src1 ->ne [0 ] == 4 ,
982+ };
983+
984+ auto it = ssm_conv_pipelines.find (key);
985+ if (it != ssm_conv_pipelines.end ()) {
986+ return it->second ;
987+ }
988+
989+ std::vector<std::string> defines;
990+ std::string variant = " ssm_conv" ;
991+
992+ switch (key.type ) {
993+ case GGML_TYPE_F32:
994+ variant += " _f32" ;
995+ break ;
996+ default :
997+ GGML_ABORT (" Unsupported type for ssm_conv shader" );
998+ }
999+
1000+ if (key.vectorized ) {
1001+ defines.push_back (" VECTORIZED" );
1002+ variant += " _vec4" ;
1003+ }
1004+
1005+ constexpr uint32_t block_size = 32u ;
1006+ constexpr uint32_t tokens_per_wg = 8u ;
1007+
1008+ defines.push_back (" BLOCK_SIZE=" + std::to_string (block_size) + " u" );
1009+ defines.push_back (" TOKENS_PER_WG=" + std::to_string (tokens_per_wg) + " u" );
1010+
1011+ auto processed = preprocessor.preprocess (wgsl_ssm_conv, defines);
1012+ auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
1013+ decisions->block_size = block_size;
1014+ decisions->tokens_per_wg = tokens_per_wg;
1015+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline (device, processed, variant);
1016+ pipeline.context = decisions;
1017+ ssm_conv_pipelines[key] = pipeline;
1018+ return ssm_conv_pipelines[key];
1019+ }
1020+
    // Build (or fetch from cache) the compute pipeline for the gated
    // delta-net operator. Keyed on dst type, the dimension S_v taken from
    // src2->ne[0], and the KDA-variant flag.
    webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_gated_delta_net_pipeline_key key = {
            .type = context.dst->type,
            .s_v  = (int) context.src2->ne[0],
            // KDA variant when src3 and src2 share their leading dimension —
            // presumably key dim == value dim; TODO confirm semantics.
            .kda  = context.src3->ne[0] == context.src2->ne[0],
        };

        auto it = gated_delta_net_pipelines.find(key);
        if (it != gated_delta_net_pipelines.end()) {
            return it->second;
        }

        std::vector<std::string> defines;
        std::string variant = "gated_delta_net";

        switch (key.type) {
            case GGML_TYPE_F32:
                variant += "_f32";
                break;
            default:
                GGML_ABORT("Unsupported type for gated_delta_net shader");
        }

        if (key.kda) {
            defines.push_back("KDA");
            variant += "_kda";
        }

        // NOTE(review): WG_SIZE is set to S_v with no clamp against
        // context.max_wg_size (unlike e.g. solve_tri) — confirm S_v can
        // never exceed the device's workgroup-size limit.
        defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
        defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");

        auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
        // NOTE(review): unlike sibling getters, no shader-decisions context
        // is attached here — confirm the dispatcher does not read one.
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        gated_delta_net_pipelines[key] = pipeline;
        return gated_delta_net_pipelines[key];
    }
1057+
7971058 webgpu_pipeline get_pad_pipeline (const ggml_webgpu_shader_lib_context & context) {
7981059 ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32 (context.dst , 8 ) != 0 };
7991060
0 commit comments