2323"""
2424import os
2525import argparse
26- from typing import Literal , Union
26+ from typing import List , Literal , Union
2727
2828import uvicorn
2929
3030from llama_cpp .server .app import create_app , Settings
3131
32- def get_non_none_base_types (annotation ):
33- if not hasattr (annotation , "__args__" ):
34- return annotation
35- return [arg for arg in annotation .__args__ if arg is not type (None )][0 ]
36-
3732def get_base_type (annotation ):
3833 if getattr (annotation , '__origin__' , None ) is Literal :
3934 return type (annotation .__args__ [0 ])
4035 elif getattr (annotation , '__origin__' , None ) is Union :
4136 non_optional_args = [arg for arg in annotation .__args__ if arg is not type (None )]
4237 if non_optional_args :
4338 return get_base_type (non_optional_args [0 ])
39+ elif getattr (annotation , '__origin__' , None ) is list or getattr (annotation , '__origin__' , None ) is List :
40+ return get_base_type (annotation .__args__ [0 ])
4441 else :
4542 return annotation
4643
44+ def contains_list_type (annotation ) -> bool :
45+ origin = getattr (annotation , '__origin__' , None )
46+
47+ if origin is list or origin is List :
48+ return True
49+ elif origin in (Literal , Union ):
50+ return any (contains_list_type (arg ) for arg in annotation .__args__ )
51+ else :
52+ return False
53+
54+
4755if __name__ == "__main__" :
4856 parser = argparse .ArgumentParser ()
4957 for name , field in Settings .model_fields .items ():
@@ -53,6 +61,7 @@ def get_base_type(annotation):
5361 parser .add_argument (
5462 f"--{ name } " ,
5563 dest = name ,
64+ nargs = "*" if contains_list_type (field .annotation ) else None ,
5665 type = get_base_type (field .annotation ) if field .annotation is not None else str ,
5766 help = description ,
5867 )
0 commit comments