@@ -43,6 +43,7 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
4343}
4444
4545
46+ #include < cutlass/version.h>
4647#include < cutlass/core_io.h>
4748#include < cutlass/cutlass.h>
4849#include < cutlass/gemm/device/gemm.h>
@@ -174,7 +175,11 @@ void f8f8bf16_rowwise_impl(
174175
175176 // Implement rowwise scaling epilogue.
176177 constexpr int ColBroadcastStages = 0 ;
178+ #if CUTLASS_VERSION == 351
179+ constexpr int RowBroadcastStages = 0 ;
180+ #else
177181 constexpr int RowBroadcastStages = PingPong::value ? 2 : 1 ;
182+ #endif
178183
179184 using XScale = cutlass::epilogue::fusion::
180185 Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
@@ -191,15 +196,24 @@ void f8f8bf16_rowwise_impl(
191196
192197 using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
193198
199+ #if CUTLASS_VERSION == 351
200+ using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
201+ Multiply,
202+ WScale,
203+ cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
204+ #else
205+ using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
206+ Multiply,
207+ XScale,
208+ cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale, Accum>>;
209+ #endif
210+
194211 using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
195212 Cast,
196213 cutlass::epilogue::fusion::Sm90EVT<
197214 Add,
198215 Bias,
199- cutlass::epilogue::fusion::Sm90EVT<
200- Multiply,
201- XScale,
202- cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale, Accum>>>>;
216+ AccumScale>>;
203217
204218 using CollectiveEpilogue =
205219 typename cutlass::epilogue::collective::CollectiveBuilder<
0 commit comments