@@ -11638,6 +11638,21 @@ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, fl
1163811638 return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
1163911639}
1164011640
// Precompute the rotation coefficients for one position into `cache`.
//
// The cache is laid out in pairs: for every even index i0 in [0, ne0),
// cache[i0 + 0] holds the first value produced by rope_yarn() (consumed as
// cos_theta by the callers) and cache[i0 + 1] holds the second value
// (consumed as sin_theta), already multiplied by sin_sign.
//
//   theta_base  - starting angle; multiplied by theta_scale after each pair
//   freq_scale, corr_dims, ext_factor, mscale - forwarded unchanged to rope_yarn()
//   ne0         - number of cache entries to fill, walked two at a time
//   cache       - output buffer; must be writable up to index i0 + 1 for every pair visited
//   sin_sign    - factor applied to the second value of every pair
//   theta_scale - per-pair multiplicative step of the angle
static void ggml_rope_cache_init(
     float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
     float * cache, float sin_sign, float theta_scale
) {
    float angle = theta_base;
    int64_t pair = 0;

    while (pair < ne0) {
        float * const cos_slot = cache + pair;
        float * const sin_slot = cos_slot + 1;

        rope_yarn(angle, freq_scale, corr_dims, pair, ext_factor, mscale, cos_slot, sin_slot);

        // fold the sign into the stored value so readers can use the cache as-is
        *sin_slot *= sin_sign;

        // advance to the angle of the next pair
        angle *= theta_scale;
        pair  += 2;
    }
}
11655+
1164111656void ggml_rope_yarn_corr_dims(
1164211657 int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
1164311658) {
@@ -11720,6 +11735,12 @@ static void ggml_compute_forward_rope_f32(
1172011735 for (int64_t i3 = 0; i3 < ne3; i3++) {
1172111736 for (int64_t i2 = 0; i2 < ne2; i2++) {
1172211737 const int64_t p = pos[i2];
11738+
11739+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
11740+ if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
11741+ ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
11742+ }
11743+
1172311744 for (int64_t i1 = 0; i1 < ne1; i1++) {
1172411745 if (ir++ < ir0) continue;
1172511746 if (ir > ir1) break;
@@ -11753,18 +11774,13 @@ static void ggml_compute_forward_rope_f32(
1175311774 }
1175411775 } else if (!is_neox) {
1175511776 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11756- float cos_theta, sin_theta;
11757- rope_yarn(
11758- theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11759- );
11760- sin_theta *= sin_sign;
11777+ const float cos_theta = cache[i0 + 0];
11778+ const float sin_theta = cache[i0 + 1];
1176111779
1176211780 // zeta scaling for xPos only:
1176311781 float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
1176411782 if (xpos_down) zeta = 1.0f / zeta;
1176511783
11766- theta_base *= theta_scale;
11767-
1176811784 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1176911785 float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1177011786
@@ -11888,6 +11904,12 @@ static void ggml_compute_forward_rope_f16(
1188811904 for (int64_t i3 = 0; i3 < ne3; i3++) {
1188911905 for (int64_t i2 = 0; i2 < ne2; i2++) {
1189011906 const int64_t p = pos[i2];
11907+
11908+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
11909+ if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
11910+ ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
11911+ }
11912+
1189111913 for (int64_t i1 = 0; i1 < ne1; i1++) {
1189211914 if (ir++ < ir0) continue;
1189311915 if (ir > ir1) break;
@@ -11921,13 +11943,8 @@ static void ggml_compute_forward_rope_f16(
1192111943 }
1192211944 } else if (!is_neox) {
1192311945 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11924- float cos_theta, sin_theta;
11925- rope_yarn(
11926- theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11927- );
11928- sin_theta *= sin_sign;
11929-
11930- theta_base *= theta_scale;
11946+ const float cos_theta = cache[i0 + 0];
11947+ const float sin_theta = cache[i0 + 1];
1193111948
1193211949 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1193311950 ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -16722,6 +16739,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
1672216739 }
1672316740 } break;
1672416741 case GGML_OP_SOFT_MAX:
16742+ case GGML_OP_ROPE:
1672516743 {
1672616744 cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
1672716745 } break;
0 commit comments