@@ -92,6 +92,7 @@ def __init__(
9292 logits_all : bool = False ,
9393 embedding : bool = False ,
9494 offload_kqv : bool = True ,
95+ flash_attn : bool = False ,
9596 # Sampling Params
9697 last_n_tokens_size : int = 64 ,
9798 # LoRA Params
@@ -168,6 +169,7 @@ def __init__(
168169 logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
169170 embedding: Embedding mode only.
170171 offload_kqv: Offload K, Q, V to GPU.
172+ flash_attn: Use flash attention.
171173 last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
172174 lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
173175 lora_path: Path to a LoRA file to apply to the model.
@@ -310,6 +312,7 @@ def __init__(
310312 ) # Must be set to True for speculative decoding
311313 self .context_params .embeddings = embedding # TODO: Rename to embeddings
312314 self .context_params .offload_kqv = offload_kqv
315+ self .context_params .flash_attn = flash_attn
313316 # KV cache quantization
314317 if type_k is not None :
315318 self .context_params .type_k = type_k
@@ -1774,6 +1777,7 @@ def __getstate__(self):
17741777 logits_all = self .context_params .logits_all ,
17751778 embedding = self .context_params .embeddings ,
17761779 offload_kqv = self .context_params .offload_kqv ,
1780+ flash_attn = self .context_params .flash_attn ,
17771781 # Sampling Params
17781782 last_n_tokens_size = self .last_n_tokens_size ,
17791783 # LoRA Params
0 commit comments