-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtest.py
More file actions
276 lines (244 loc) · 7.96 KB
/
test.py
File metadata and controls
276 lines (244 loc) · 7.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import os
from copy import deepcopy
import nvtx
import time
import math
import pytest
import torch
import numpy as np
from einops import repeat
import triton
try:
from vllm.vllm_flash_attn import (
flash_attn_with_kvcache as vllm_flash_attn_with_kvcache,
)
except ImportError:
print("[WARN] vllm_flash_attn_with_kvcache not found, skip related tests.")
vllm_flash_attn_with_kvcache = None
from prefix_attn import (
prefix_attn_with_kvcache,
PackedBox,
Sequence,
SeqGroup,
pack_prefix_blocks,
pack_without_prefix,
KernelInfo,
)
from prefix_attn import (
generate_random_prefix_seqs,
generate_random_kv_cache,
generate_tree_seqs,
)
from prefix_attn.calc_theoretical import calc_theoretical
from prefix_attn.data_class import create_seq_group
from prefix_attn.block_scheduler import schedule, schedule_naive
from prefix_attn.utils import make_tensor_with_pad
from prefix_attn.prefix_tree import PrefixTree
from utils import CPUTimer, attention_ref, GPUTimer, Timer
from prefix_attn import PrefixTreeCPP
def tree_benchmark(
    seq_group: SeqGroup,
    num_blocks: int,
    nheads_q,  # [attn]
    nheads_kv,  # [attn]
    head_dim,  # [attn]
    block_size,  # [page]
    n_repeats,
    dtype,
    device,
    seed,
    baselines: list = None,
    enable_timing: bool = False,
):
    """Run prefix attention on a random paged KV cache and check it against a baseline.

    The function generates a random query and paged KV cache for ``seq_group``,
    builds and schedules a ``PrefixTreeCPP`` from the block table, times
    ``prefix_attn_with_kvcache`` and, when requested and available, times
    vLLM's ``flash_attn_with_kvcache`` and compares both outputs.

    Args:
        seq_group: Sequences to benchmark; ``len(seq_group)`` is used as the
            batch size of the query and ``seq_group.seqlens`` feeds the tree.
        num_blocks: Forwarded to ``generate_random_kv_cache``.
        nheads_q: Number of query heads; must be a multiple of ``nheads_kv``.
        nheads_kv: Number of key/value heads.
        head_dim: Head dimension; only 64 and 128 are accepted.
        block_size: Page size forwarded to the cache generator and the tree.
        n_repeats: Repeat count forwarded to ``Timer``.
        dtype: Torch dtype of the query and output tensors.
        device: CUDA device the tensors live on.
        seed: Seed for torch's global RNG (used for the query tensor).
        baselines: Names of baselines to run; the vLLM flash-attention
            baseline runs when it contains ``"all"`` or ``"vllm-fa"``.
            ``None`` disables all baselines.
        enable_timing: Deprecated; must be ``False``.

    Raises:
        AssertionError: On invalid arguments, or when the max absolute
            difference to the flash-attention output exceeds ``1e-2``.
        ValueError: If a block-table row with more than one entry ends in 0,
            which is treated as padding.
    """
    torch.cuda.set_device(device)
    torch.random.manual_seed(seed)
    assert head_dim in [64, 128], "head_dim should be 64 or 128 currently"
    assert nheads_q % nheads_kv == 0
    assert enable_timing is False, "timing is deprecated"
    scale = 1.0 / math.sqrt(head_dim)
    max_error = 1e-2

    # ----- step-1: generate random q k_cache v_cache tensors -------------------------------------------------
    q = torch.randn(len(seq_group), 1, nheads_q, head_dim, device=device, dtype=dtype)
    # NOTE(review): the KV cache is seeded from the wall clock, not from `seed`,
    # so a numerical failure cannot be replayed exactly. Left unchanged because
    # the seeding semantics of generate_random_kv_cache are not visible here -
    # confirm whether `seed` should be forwarded instead.
    (
        seqlens,
        _k_cache,
        _v_cache,
        _kv_padding_mask,
        block_table,
        k_cache_paged,
        v_cache_paged,
    ) = generate_random_kv_cache(
        seq_group=seq_group,
        num_blocks=num_blocks,
        block_size=block_size,
        nheads_kv=nheads_kv,
        head_dim=head_dim,
        dtype=dtype,
        device=device,
        seed=int(time.time()),
    )

    # One host copy serves both the padding check and the tree construction.
    table = block_table.cpu()

    # The table is expected to be unpadded: a trailing 0 in a multi-entry row
    # is treated as padding and rejected.
    for i, row in enumerate(table.tolist()):
        if len(row) > 1 and row[-1] == 0:
            raise ValueError(
                f"Table is expected to be unpadded, but row {i} has padding 0 at the end."
            )

    # ----- step-2: schedule and transfer results to GPU (count to overhead) ----------------------------------
    # The result is currently unused; the call is kept so its behavior
    # (including any error it raises) is unchanged.
    calc_theoretical(block_table, nheads_q, nheads_kv, head_dim)

    MNWs = None
    seq_lens = seq_group.seqlens
    tree = PrefixTreeCPP(block_size)
    tree.build_radix_tree(seq_lens, table)
    tree.pack_schedule(MNWs, nheads_q // nheads_kv, nheads_kv)
    tree.kernel_info.to_gpu(torch.device(device))

    # ----- step-3: run and compare results -------------------------------------------------------------------
    # prefix-attention
    out = torch.empty_like(q, device=device, dtype=dtype)

    def pa_func():
        # Writes its result into `out`, which is compared with the baseline below.
        prefix_attn_with_kvcache(
            q=q,
            k_cache_paged=k_cache_paged,
            v_cache_paged=v_cache_paged,
            tree=tree,
            softmax_scale=scale,
            out=out,
        )

    report = [f"[INFO] kernel latency: PAT={Timer(pa_func, n_repeats):.3f}ms"]

    # vllm-flash-attention
    run_vllm_fa = (
        vllm_flash_attn_with_kvcache is not None
        and baselines is not None
        and ("all" in baselines or "vllm-fa" in baselines)
    )
    if run_vllm_fa:

        def flash_func():
            return vllm_flash_attn_with_kvcache(
                q,
                k_cache_paged,
                v_cache_paged,
                cache_seqlens=seqlens,
                block_table=block_table,
                causal=True,
                num_splits=0,  # 0: best split; 1: do not split kv
            )

        out_flash = flash_func()
        # Compute the element-wise error once instead of three times.
        abs_diff = (out - out_flash).abs()
        max_diff = abs_diff.max().item()
        print(f"[DEBUG] Max diff (PAT - FlashAttention): {max_diff}")
        print(f"[DEBUG] Mean diff (PAT - FlashAttention): {abs_diff.mean().item()}")
        report.append(f" | FlashAttention={Timer(flash_func, n_repeats):.3f}ms")
        assert max_diff <= max_error, "out_flash error"

    print("".join(report))
    print("[INFO] successfully pass the test!")
@pytest.mark.parametrize(
    "tree",
    [
        "1,2,4,16,32,64,128,256,512_3072,1024,32,32,32,32,32,32,32",
        "1,2,4,16,32,64,256_4096,32,128,128,32,256,512",
        "1,2,4,16,32,64,128,256,512_3072,1024,512,256,128,32,256,512,32",
        "1,2,4,16,32,64,128,256,1024_512,32,128,128,32,256,512,32,32",
        "1,2,4,16,32,64,128,512_512,32,128,128,32,256,512,256",
        "1,2,4,16,32,64,128_512,32,128,128,32,256,512",
        "4_4096",
        "4,16_4096,256",
        "1,8_640,32",
        "1,4,16_32,256,32",
        "1,2,4,8,16,32,64_32,512,32,256,32,32,32",
        "1,4,256_32,256,32",
        "1,16,32_2048,1024,256",
        "1,10_4096,416",
    ],
)
@pytest.mark.parametrize("nheads_q,nheads_kv", ([64, 8],))
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("block_size", [32])
@pytest.mark.parametrize("n_repeats", [20])
def test_tree_attn_kvcache(
    nheads_q,  # [attn]
    nheads_kv,  # [attn]
    head_dim,  # [attn]
    block_size,  # [page]
    tree,
    n_repeats,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda:0",
    seed: int = 0,
    # Immutable default (was a list literal, which is shared across calls).
    # Accepted names: 'all', 'fa', 'vllm-fa', 'ta'; pass None for no baselines.
    baselines=("all",),
):
    """Benchmark and check prefix attention on a tree-shaped batch.

    ``tree`` is a spec string handed unchanged to ``generate_tree_seqs``
    together with ``block_size``; the resulting sequence group and block count
    are forwarded to ``tree_benchmark``.
    NOTE(review): the spec looks like "<counts>_<lengths>" with one entry per
    tree level - confirm against ``generate_tree_seqs``.

    ``dtype``, ``device``, ``seed`` and ``baselines`` are not parametrized, so
    their defaults apply under pytest; they can be overridden when the function
    is called directly.
    """
    # os.environ["DISABLE_STREAM"] = "1"
    print(f"[DEBUG] (pytest) (tree): {tree}")
    seq_group, num_blocks = generate_tree_seqs(tree, block_size)
    tree_benchmark(
        seq_group,
        num_blocks,
        nheads_q,
        nheads_kv,
        head_dim,
        block_size,
        n_repeats,
        dtype,
        device,
        seed,
        baselines=baselines,
        enable_timing=False,
    )
def test_tree_attn_manual():
    """Debug helper: benchmark one hand-written block table."""
    # Three sequences of 33 tokens, three blocks each; the first two rows
    # share block 0, the third row uses its own blocks.
    page_size = 16
    manual_lens = [33, 33, 33]
    manual_table = [[0, 1, 5], [0, 2, 6], [3, 4, 7]]
    total_blocks = 13

    # Alternative fixtures kept for debugging:
    #   seq_lens = [17, 36, 44, 44, 32, 36, 44, 44]
    #   block_table = [[1, 2], [1, 2, 3], [4, 5, 6], [4, 5, 7],
    #                  [4, 8], [1, 9, 10], [4, 5, 11], [4, 5, 12]]
    #   seq_lens = [38, 38, 15, 15]
    #   block_table = [[1, 2, 6], [1, 5, 8], [3, 0, 0], [7, 0, 0]]

    group = create_seq_group(manual_table, manual_lens, page_size)

    tree_benchmark(
        seq_group=group,
        num_blocks=total_blocks,
        nheads_q=32,
        nheads_kv=8,
        head_dim=128,
        block_size=page_size,
        n_repeats=20,
        dtype=torch.float16,
        device="cuda:0",
        seed=0,
        baselines=["vllm-fa", "ta", "fa"],
        enable_timing=False,
    )
if __name__ == "__main__":
    # Script entry point: one small run of the parametrized test without
    # pytest. Swap in the line below to run the hand-written case instead.
    # test_tree_attn_manual()
    test_tree_attn_kvcache(
        # workload
        tree="1,256_256,32",
        block_size=32,
        n_repeats=1,
        # attention shape
        nheads_q=32,
        nheads_kv=8,
        head_dim=128,
        # runtime
        dtype=torch.float16,
        device="cuda:0",
        seed=0,
        baselines=["all"],
    )