
Conversation

@pwilkin
Collaborator

@pwilkin pwilkin commented Nov 23, 2025

I've managed to actually poke enough LLMs in the correct direction to end up with this:

CPU:

  SOLVE_TRI(type=f32,ne_lhs=[64,64,4,2],ne_rhs=[6,64,4,2]):                    32760 runs -    37.34 us/run -      152 kB/run -    3.88 GB/s

CUDA:

  SOLVE_TRI(type=f32,ne_lhs=[64,64,4,2],ne_rhs=[6,64,4,2]):                    49140 runs -    22.52 us/run -      152 kB/run -    6.44 GB/s

This can most certainly be improved by someone who knows what they're doing, but at least it does the bare minimum by supplying a CUDA kernel that's around twice as fast as the optimized CPU implementation.

@pwilkin pwilkin requested a review from slaren as a code owner November 23, 2025 22:55
@pwilkin pwilkin mentioned this pull request Nov 23, 2025
@github-actions github-actions bot added testing Everything test related Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Nov 23, 2025
@wsbagnsv1
Contributor

wsbagnsv1 commented Nov 24, 2025

Hey, I've created a small framework for OpenEvolve for this kernel and ran it for about 40 iterations (I plan to do around 600), and already got around an 8% improvement with the kernel below. I'm pretty sure I've covered all the test cases, but please check it for correctness in case I missed something. Anyway, here is the performance improvement on my old RTX 2070 for the kernel:

Performance improvement

Oh and this was reproducible with multiple runs over time and showed a consistent improvement (;

This is the kernel:

#include <cuda_fp16.h>

#define MAX_N_FAST 64
#define MAX_K_FAST 32
#define WARP_SIZE 32

// Warp reduction helper; reduces over the currently active lanes of the warp
static __inline__ __device__ float warpReduceSum(float val) {
    // __activemask() yields the mask of lanes that are currently active
    unsigned mask = __activemask();
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(mask, val, offset);
    }
    return val;
}

// Optimized kernel focusing on coalesced access and warp-level parallelism
extern "C" __global__ void solve_tri_f32_fast(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ X,
    int n, int k,
    int64_t ne02, int64_t ne03,
    size_t nb02, size_t nb03,
    size_t nb12, size_t nb13,
    size_t nb2, size_t nb3)
{
    const int batch_idx = blockIdx.x;
    const int lane      = threadIdx.x;
    const int col_idx   = threadIdx.y;
    const int tid       = threadIdx.x + threadIdx.y * blockDim.x;

    // Early exit for excess warps
    if (col_idx >= k) {
        return;
    }

    // Calculate batch indices
    const int64_t i03 = batch_idx / ne02;
    const int64_t i02 = batch_idx % ne02;

    // Get pointers for this batch
    const float* const A_batch = (const float*)((const char *)A + i02 * nb02 + i03 * nb03);
    const float* const B_batch = (const float*)((const char *)B + i02 * nb12 + i03 * nb13);
    float*             X_batch = (float*)      ((char *)X + i02 * nb2  + i03 * nb3);

    // Shared memory for A and B matrices
    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
    __shared__ float sX[MAX_N_FAST * MAX_K_FAST];

    // Coalesced loading of A matrix
    // Each thread loads multiple elements to improve bandwidth utilization
    const int total_elements_A = n * n;
    const int stride_A = blockDim.x * blockDim.y;
    for (int i = tid; i < total_elements_A; i += stride_A) {
        sA[i] = A_batch[i];
    }

    // Coalesced loading of B matrix  
    const int total_elements_B = n * k;
    const int stride_B = blockDim.x * blockDim.y;
    for (int i = tid; i < total_elements_B; i += stride_B) {
        sX[i] = B_batch[i];
    }
    __syncthreads();

    // Forward substitution with warp-level parallelism
    // Each warp processes one column of the solution
    for (int row = 0; row < n; ++row) {
        float sum = 0.0f;

        // Use register accumulation for better ILP
        float sum_part = 0.0f;
        
        // Parallel computation of dot product
        // Each thread in the warp processes a subset of elements
        for (int j = lane; j < row; j += WARP_SIZE) {
            // Load row of A and column of X from shared memory
            sum_part += sA[row * n + j] * sX[j * k + col_idx];
        }

        // Warp-level reduction to get the final sum
        sum = warpReduceSum(sum_part);

        // Lane 0 performs the final computation and stores result
        if (lane == 0) {
            const float b_val = sX[row * k + col_idx];
            const float a_diag = sA[row * n + row];
            
            // The safe exact check:
            if (a_diag != 0.0f) {
                sX[row * k + col_idx] = (b_val - sum) / a_diag;
            } else {
                sX[row * k + col_idx] = 0.0f; // Only catch true division by zero
            }
        }
        
        // Synchronize to ensure all threads see the updated value
        __syncthreads();
    }

    // Coalesced write back to global memory
    for (int i = tid; i < total_elements_B; i += stride_B) {
        X_batch[i] = sX[i];
    }
}

For those curious, OpenEvolve implements the AlphaEvolve approach from Google (or at least it started as that): using LLMs to iteratively evolve and optimize algorithms.

Also, I'll upload the framework tomorrow to my GitHub for anyone interested (;

@pwilkin
Collaborator Author

pwilkin commented Nov 24, 2025

@am17an something like this?

@am17an
Collaborator

am17an commented Nov 24, 2025

@pwilkin not quite, something like this https://github.com/pwilkin/llama.cpp/compare/solve_tri_cuda...am17an:llama.cpp:solve_tri_cuda_opt?expand=1

With this I get

SOLVE_TRI(type=f32,ne_lhs=[64,64,4,2],ne_rhs=[6,64,4,2]):                    57330 runs -    19.13 us/run -      152 kB/run -    7.58 GB/s

@am17an
Collaborator

am17an commented Nov 24, 2025

One other thing: you don't need an entirely separate function for the general case; you can pass 0 as the template parameter and use an if constexpr on the unrolled parts.
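For illustration only (not the PR's actual code), a minimal sketch of that single-kernel shape; the kernel name and parameter list here are simplified:

template <int n_template>
static __global__ void solve_tri_f32_sketch(const float * A, const float * B, float * X,
                                            const int n_runtime, const int k) {
    // n_template == 0 means "n is only known at runtime"; any other value is a compile-time size.
    const int n = (n_template == 0) ? n_runtime : n_template;

    if constexpr (n_template == 0) {
        // General path: loop bounds are runtime values, so no unrolling.
        for (int row = 0; row < n; ++row) {
            // forward substitution step for this row (omitted in this sketch)
        }
    } else {
        // Fast path: the bound is a compile-time constant, so the loop can be unrolled.
#pragma unroll
        for (int row = 0; row < n_template; ++row) {
            // forward substitution step for this row (omitted in this sketch)
        }
    }
    (void) A; (void) B; (void) X; (void) k;  // body elided; only the dispatch structure is shown
}

The host-side launcher can then instantiate the templated kernel with the common sizes and fall back to the <0> instantiation for everything else.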

__shared__ float sX[MAX_N_FAST * MAX_K_FAST];

// Load A into shared memory (coalesced)
#pragma unroll
Collaborator

These cannot be unrolled, so you can remove this.

Suggested change
#pragma unroll

Collaborator

I think this whole function can go away, it should be something like

if constexpr(n == 0) { 
   //take this path
} else {
  #pragma unroll 
  //the fast loop
}


@pwilkin
Collaborator Author

pwilkin commented Nov 24, 2025

@am17an not very experienced with this, but I believe this is what you had in mind?

Collaborator

@am17an am17an left a comment

Yeah, looks good. You can run it through clang-format once. Also, let's get @JohannesGaessler to take a look as well; he usually has the best ideas re: performance.

@theo77186
Contributor

theo77186 commented Nov 24, 2025

For some reason, when testing with test-backend-ops, this case fails: SOLVE_TRI(type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]), with wildly different NMSE (can be from 10^-4 to very large values). It may even occasionally pass. It only fails on my 4060Ti (sm89) but not on my 3060 (sm86). I don't see any reason the kernel would behave differently, though.

logs

[SOLVE_TRI] NMSE = 16.742420406 > 0.000000100   SOLVE_TRI(type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]): FAIL

edit: would appreciate if anyone with a sm89 GPU could reproduce this

@pwilkin
Collaborator Author

pwilkin commented Nov 24, 2025

> For some reason, when testing with test-backend-ops, this case fails: SOLVE_TRI(type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]), with wildly different NMSE (can be from 10^-4 to very large values). It may even occasionally pass. It only fails on my 4060Ti (sm89) but not on my 3060 (sm86). I don't see any reason the kernel would behave differently, though.
>
> logs
>
> [SOLVE_TRI] NMSE = 16.742420406 > 0.000000100   SOLVE_TRI(type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]): FAIL
>
> edit: would appreciate if anyone with a sm89 GPU could reproduce this

Can confirm. My 3080 (which is my CUDA0) works correctly, but my 5060 fails with the pattern you described. @am17an any ideas? Looks like some race condition, but why only on 89+ arch cards?

@pwilkin
Collaborator Author

pwilkin commented Nov 24, 2025

Never mind; I needed to add a guard. Apparently the 30x0 cards didn't mind.

Comment on lines 42 to 44
const float * const A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3);
Collaborator

Generally speaking it is preferable to pass the strides in units of float instead of char.
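A small sketch of the two conventions (helper names are hypothetical; s02/s03 stand for nb02/nb03 already divided by sizeof(float) on the host):

// Byte strides (nb02, nb03): pointer arithmetic has to detour through char*.
// (assumes <cstdint>/<cstddef> or the usual common.cuh includes)
static __device__ const float * batch_ptr_bytes(const float * A, const int64_t i02, const int64_t i03,
                                                const size_t nb02, const size_t nb03) {
    return (const float *) ((const char *) A + i02 * nb02 + i03 * nb03);
}

// Float-unit strides: plain pointer arithmetic, no casts.
static __device__ const float * batch_ptr_floats(const float * A, const int64_t i02, const int64_t i03,
                                                 const int64_t s02, const int64_t s03) {
    return A + i02 * s02 + i03 * s03;
}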

@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_SOLVE_TRI_BLOCK_SIZE 256
Collaborator

The kernel does not respect this define. More generally, you are launching a kernel with up to 32*32=1024 threads which is in principle still possible but becomes problematic in terms of register pressure. My recommendation would be to launch at most 256 threads, to specify this upper limit via __launch_bounds__, and to handle the cases which currently use > 8 warps with a loop.
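Roughly, a sketch of that launch pattern (kernel name and indexing are illustrative only, not the final code):

#define CUDA_SOLVE_TRI_BLOCK_SIZE 256

// Cap the block at 256 threads (8 warps) and let each warp loop over several columns
// when k > 8, instead of launching one warp per column (up to 32 warps).
static __global__ void __launch_bounds__(CUDA_SOLVE_TRI_BLOCK_SIZE)
    solve_tri_f32_capped(const float * A, const float * B, float * X, const int n, const int k) {
    const int lane    = threadIdx.x;  // 0..31
    const int warp_id = threadIdx.y;  // 0..blockDim.y-1, with blockDim.x == 32
    const int n_warps = blockDim.y;   // at most 8 with a 32x8 block

    for (int col = warp_id; col < k; col += n_warps) {
        // solve column `col` with this warp (omitted in this sketch)
    }
    (void) A; (void) B; (void) X; (void) n; (void) lane;  // body elided
}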

@JohannesGaessler
Collaborator

> Never mind; I needed to add a guard. Apparently the 30x0 cards didn't mind.

I think the reason you needed to add a guard is because the number of warps can be != a power of 2.

@pwilkin
Collaborator Author

pwilkin commented Nov 27, 2025

@wsbagnsv1 BTW, care to upload your OpenEvolve evaluator somewhere?

@pwilkin
Collaborator Author

pwilkin commented Nov 27, 2025

@wsbagnsv1 BTW it's always good to look at the generated kernels as well :)

Your kernel has completely unneeded writes to an extra array: you only write to sXt and then back to X_batch directly (transposing on the fly), so you don't need sX under that approach at all.

@pwilkin
Collaborator Author

pwilkin commented Nov 27, 2025

@JohannesGaessler @am17an I think this is as far as I can push this; I believe it's already pretty well optimized.

@am17an
Collaborator

am17an commented Nov 27, 2025

Nice! I think you need another clang-format pass, and preferably pass the strides in units of float rather than bytes.

@pwilkin
Collaborator Author

pwilkin commented Nov 27, 2025

Cleaned it up. Final results:

  Device description: NVIDIA GeForce RTX 3080
  Device memory: 9871 MB (15969 MB free)

  SOLVE_TRI(type=f32,ne_lhs=[64,64,4,2],ne_rhs=[6,64,4,2]):                   106470 runs -     9.51 us/run -      152 kB/run -   15.24 GB/s
  
  Device description: NVIDIA GeForce RTX 5060 Ti
  Device memory: 15848 MB (15958 MB free)

  SOLVE_TRI(type=f32,ne_lhs=[64,64,4,2],ne_rhs=[6,64,4,2]):                   147420 runs -     7.11 us/run -      152 kB/run -   20.38 GB/s

Collaborator

@am17an am17an left a comment

Just a few nitpicks, can merge after that.

@wsbagnsv1

wsbagnsv1 commented Nov 27, 2025

> @wsbagnsv1 BTW, care to upload your OpenEvolve evaluator somewhere?

Yeah, I have to clean it up a bit and then I'll put it on my GitHub; I had to hardcode the paths for the compilation etc. 😅
It's not that hard to run: basically it compiles llama.cpp each time with the new kernel and runs perf as well as Nsight on it (; I aim to do that either today or tomorrow (;

@wsbagnsv1

wsbagnsv1 commented Nov 27, 2025

> @wsbagnsv1 BTW it's always good to look at the generated kernels as well :)
>
> Your kernel has completely unneeded writes to an extra array: you only write to sXt and then back to X_batch directly (transposing on the fly), so you don't need sX under that approach at all.

Probably because it was a rather early iteration; I'll check the later ones once I'm back on my PC. But yeah, you're right; I didn't have much time left before sleep, saw that massive improvement, and basically just copy-pasted it so it's at least online 😅

@jeffbolznv
Collaborator

Perf on my system:

Backend 1/3: CUDA0
  Device description: NVIDIA GeForce RTX 5090
  Device memory: 32606 MB (30991 MB free)

  SOLVE_TRI(type=f32,ne_lhs=[64,64,4,2],ne_rhs=[6,64,4,2]):                   139230 runs -     7.20 us/run -      152 kB/run -   20.15 GB/s
  SOLVE_TRI(type=f32,ne_lhs=[128,128,4,1],ne_rhs=[8,128,4,1]): not supported
  Backend CUDA0: OK
Backend 2/3: CUDA1
  Device description: NVIDIA GeForce RTX 4070
  Device memory: 12281 MB (11106 MB free)

  SOLVE_TRI(type=f32,ne_lhs=[64,64,4,2],ne_rhs=[6,64,4,2]):                   147420 runs -     6.91 us/run -      152 kB/run -   20.99 GB/s
  SOLVE_TRI(type=f32,ne_lhs=[128,128,4,1],ne_rhs=[8,128,4,1]): not supported

@theo77186
Contributor

my own results

4060Ti and 3060 performance
ggml_backend_cuda_get_available_uma_memory: final available_memory_kb: 23731724
Backend 1/4: CUDA0
  Device description: NVIDIA GeForce RTX 4060 Ti
  Device memory: 15982 MB (23175 MB free)

  SOLVE_TRI(type=f32,ne_lhs=[64,64,4,2],ne_rhs=[6,64,4,2]):                   155610 runs -     6.67 us/run -      152 kB/run -   21.73 GB/s
  SOLVE_TRI(type=f32,ne_lhs=[128,128,4,1],ne_rhs=[8,128,4,1]): not supported
  Backend CUDA0: OK
ggml_backend_cuda_get_available_uma_memory: final available_memory_kb: 23597052
Backend 2/4: CUDA1
  Device description: NVIDIA GeForce RTX 3060
  Device memory: 11947 MB (23043 MB free)

  SOLVE_TRI(type=f32,ne_lhs=[64,64,4,2],ne_rhs=[6,64,4,2]):                   122850 runs -     8.69 us/run -      152 kB/run -   16.69 GB/s
  SOLVE_TRI(type=f32,ne_lhs=[128,128,4,1],ne_rhs=[8,128,4,1]): not supported
  Backend CUDA1: OK

Curiously, it seems there isn't much scaling within a given GPU family...

@pwilkin
Collaborator Author

pwilkin commented Nov 27, 2025

@theo77186 it's a small kernel that doesn't utilize all the cores, hence no scaling. It would probably scale more with a tiled version.

@wsbagnsv1
Contributor

wsbagnsv1 commented Nov 27, 2025

> @theo77186 it's a small kernel that doesn't utilize all the cores, hence no scaling. It would probably scale more with a tiled version.

I've just created my fork of OpenEvolve and uploaded my changes with the solve_tri kernel; maybe it can be used to find a properly optimized tiled version of this kernel too, if we even need it, though we'd have to change the parser in the evaluator and the config.yaml. But that's for another PR, I guess 😉
https://github.com/wsbagnsv1/openevolve-cuda-trisolve

@pwilkin
Collaborator Author

pwilkin commented Nov 27, 2025

@am17an Aight, CI tests are fine; I think we can merge.

@am17an am17an merged commit cd0e3a7 into ggml-org:master Nov 28, 2025
72 of 74 checks passed
float * X_batch = (float *) (X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float));

__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
__shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
Collaborator

Simply changing the size of a 1D allocation like this does nothing to fix shared memory bank conflicts. You have to actually access elements with the padded stride. One way to do this automatically is to change the array shape to be 2D and to pad the last dimension.
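For example, a sketch of the padded 2D form (reusing the MAX_N_FAST/MAX_K_FAST defines; the kernel body is only a placeholder to show the indexing):

#define MAX_N_FAST 64
#define MAX_K_FAST 32

// With a row stride of 64 floats, a warp that reads one element per row (sXt[lane][j])
// hits the same shared-memory bank 32 times; padding the last dimension to 65 spreads
// those accesses across different banks. Because the padding is part of the array type,
// every access automatically uses the padded stride, which enlarging a 1D allocation does not give you.
static __global__ void padded_smem_sketch(float * out, const int n, const int k) {
    __shared__ float sXt[MAX_K_FAST][MAX_N_FAST + 1];

    const int lane    = threadIdx.x;
    const int col_idx = threadIdx.y;

    // indexing changes from sXt[col_idx * n + i0] to sXt[col_idx][i0]
    for (int i0 = lane; i0 < n; i0 += 32) {
        sXt[col_idx][i0] = 0.0f;
    }
    __syncthreads();
    if (lane < n && col_idx < k) {
        out[lane * k + col_idx] = sXt[col_idx][lane];
    }
}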

Collaborator Author

@JohannesGaessler You're right, I didn't think this one through :) will try to fix it and submit a separate PR.


@pwilkin could you link it here? Thanks

@JohannesGaessler
Collaborator

Sorry, I forgot to press the submit button on my review. It is one more optimization that could be done, though for any optimizations you should first check what percentage of the runtime the operation takes up in the first place (because that is the maximum percentage that you can shave off).

@wsbagnsv1
Contributor

> Sorry, I forgot to press the submit button on my review. It is one more optimization that could be done, though for any optimizations you should first check what percentage of the runtime the operation takes up in the first place (because that is the maximum percentage that you can shave off).

I let my OpenEvolve framework run more iterations and am now at a kernel that has an 18% (2070) to 23% (4070 Ti) speedup in bandwidth over the current kernel. If I measured it correctly, it seems the math runs around 36 times per encode operation for 36 active layers. At the moment this doesn't seem to have much impact on encode time, so when would it even make sense to create a PR?

@pwilkin
Collaborator Author

pwilkin commented Nov 28, 2025

@wsbagnsv1 of course, create a PR and don't worry about the impact; if you can show a measurable increase in bandwidth, then it's worth uploading.

@wsbagnsv1
Contributor

Okay, then I'll create a draft PR (since I didn't max it out yet), or just create it in a few days (;

@wsbagnsv1
Contributor

BTW I have updated the framework on my GitHub, so it should now be a lot simpler to use for other kernels, and it should also be a lot more efficient (;

@JohannesGaessler
Collaborator

If there is no meaningful difference for end-to-end performance we should not be merging changes that come at the cost of maintainability though.

@wsbagnsv1
Contributor

wsbagnsv1 commented Nov 28, 2025

> If there is no meaningful difference for end-to-end performance we should not be merging changes that come at the cost of maintainability though.

I mean, I can show you the changes; it stays roughly the same number of lines, but there are still quite a few. Would such a change be accepted with a clear performance gain in the kernel itself but not in end-to-end performance?
It would be interesting to know even outside of this specific kernel, since I plan to at least try to optimize basically all CUDA and maybe Vulkan kernels in the future with this method (;
(Since this is not the end result, there might be redundant stuff in it that can be further improved/cleaned up.)

@@ -3,7 +3,6 @@
 #include "solve_tri.cuh"
 
 #define MAX_N_FAST 64
-#define MAX_K_FAST 32
 
 // ======================
 // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
@@ -14,7 +13,7 @@
 #ifdef __clang__
 #    pragma clang diagnostic push
 #    pragma clang diagnostic ignored "-Wpass-failed"
-#endif  // __clang__
+#endif // __clang__*
 template <int n_template, int k_template>
 static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
                                           const float * __restrict__ B,
@@ -43,12 +42,11 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     const int64_t i02     = i02_i03.y;
     const int64_t i03     = i02_i03.x;
 
-    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
-    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
-    float *             X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+    const float * const A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13);
+    float *             X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3);
 
     __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
-    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
 
     const int offset = threadIdx.x + threadIdx.y * blockDim.x;
 
@@ -60,53 +58,60 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
         }
     }
 
-    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
+    __syncthreads();
 
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
-        }
-    }
+    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
+    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
 
-    __syncthreads();
+    const int half = WARP_SIZE;
+    const int nrows_low = (n < half) ? n : half;
 
+    // Process lower rows
 #pragma unroll
-    for (int row = 0; row < n; ++row) {
+    for (int row = 0; row < nrows_low; ++row) {
         float sum = 0.0f;
-
-        {
-            int j = lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        if (lane < row) {
+            sum = fmaf(sA[row * n + lane], x_low, sum);
         }
-        if (row >= WARP_SIZE) {
-            int j = WARP_SIZE + lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        sum = warp_reduce_sum(sum);
+
+        float diag = sA[row * n + row];
+        float idiv = 1.0f / diag;
+        float b_val = __shfl_sync(0xffffffffu, x_low, row);
+        float new_x = fmaf(sum, -idiv, b_val * idiv);
+
+        if (lane == row) {
+            x_low = new_x;
         }
+    }
 
+    // Process upper rows
+#pragma unroll
+    for (int row = half; row < n; ++row) {
+        float sum = fmaf(sA[row * n + lane], x_low, 0.0f);
+        int j = half + lane;
+        if (j < row) {
+            sum = fmaf(sA[row * n + j], x_high, sum);
+        }
         sum = warp_reduce_sum(sum);
 
-        if (lane == 0) {
-            const float b_val      = sXt[col_idx * n + row];
-            const float a_diag     = sA[row * n + row];
-            // no safeguards for division by zero because that indicates corrupt
-            // data anyway
-            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
+        float diag = sA[row * n + row];
+        float idiv = 1.0f / diag;
+        float b_val = __shfl_sync(0xffffffffu, x_high, row - half);
+        float new_x = fmaf(sum, -idiv, b_val * idiv);
+
+        if (lane == row - half) {
+            x_high = new_x;
         }
     }
 
-    __syncthreads();
-
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+    // Warp-wise store
+#pragma unroll 2
+    for (int rr = 0; rr < 2; ++rr) {
+        int row = rr * WARP_SIZE + lane;
+        if (row < n) {
+            float val = (row < half) ? x_low : x_high;
+            X_batch[row * k + col_idx] = val;
         }
     }
 }
@@ -197,7 +202,6 @@ void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     GGML_ASSERT(k <= 32);
 
     solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
-                       src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
-                       src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
-                       dst->nb[3] / sizeof(float), ctx.stream());
+                       src0->ne[3], src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], dst->nb[2], dst->nb[3],
+                       ctx.stream());
 }

The framework I've used, based on OpenEvolve: https://github.com/wsbagnsv1/openevolve-cuda-trisolve

@JohannesGaessler
Collaborator

Changes like that would be acceptable.

@wsbagnsv1
Contributor

> Changes like that would be acceptable.

Well then, I'll let it continue to optimize that overnight and create a PR tomorrow ❤️
