You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add support to customized vectordb and embedding functions (#161)
* Add custom embedding function
* Add support to custom vector db
* Improve docstring
* Improve docstring
* Improve docstring
* Add support to customized is_termination_msg fucntion
* Add a test for customize vector db with lancedb
* Fix tests
* Add test for embedding_function
* Update docstring
retrieve_config: Optional[Dict] =None, # config for the retrieve agent
71
72
**kwargs,
72
73
):
@@ -82,14 +83,17 @@ def __init__(
82
83
the number of auto reply reaches the max_consecutive_auto_reply.
83
84
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
84
85
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
86
+
is_termination_msg (function): a function that takes a message in the form of a dictionary
87
+
and returns a boolean value indicating if this received message is a termination message.
88
+
The dict can contain the following keys: "content", "role", "name", "function_call".
85
89
retrieve_config (dict or None): config for the retrieve agent.
86
90
To use default config, set to None. Otherwise, set to a dictionary with the following keys:
87
91
- task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System
88
92
prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
89
-
- client (Optional, chromadb.Client): the chromadb client.
90
-
If key not provided, a default client `chromadb.Client()` will be used.
93
+
- client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
94
+
will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function.
91
95
- docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
92
-
or the url to a single file. If key not provided, a default path `./docs` will be used.
96
+
or the url to a single file. Default is None, which works only if the collection is already created.
93
97
- collection_name (Optional, str): the name of the collection.
94
98
If key not provided, a default name `autogen-docs` will be used.
95
99
- model (Optional, str): the model to use for the retrieve chat.
@@ -106,16 +110,45 @@ def __init__(
106
110
If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models
107
111
can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
108
112
fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
113
+
- embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None,
114
+
SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or
115
+
other embedding functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
109
116
- customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
110
117
- customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
111
118
If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
112
119
- update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
113
120
- get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
114
-
This is the same as that used in chromadb. Default is False.
121
+
This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
115
122
- custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
116
123
The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
117
124
Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
118
125
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
126
+
127
+
Example of overriding retrieve_docs:
128
+
If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with below code.
129
+
```python
130
+
class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
"""Create a vector db from all the files in a given directory."""
269
+
"""Create a vector db from all the files in a given directory, the directory can also be a single file or a url to
270
+
a single file. We support chromadb compatible APIs to create the vector db, this function is not required if
271
+
you prepared your own vector db.
272
+
273
+
Args:
274
+
dir_path (str): the path to the directory, file or url.
275
+
max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
276
+
client (Optional, API): the chromadb client. Default is None.
277
+
db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
278
+
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
279
+
get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection
280
+
will be recreated if it already exists.
281
+
chunk_mode (Optional, str): the chunk mode. Default is "multi_lines".
282
+
must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True.
283
+
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
284
+
embedding_function is not None.
285
+
embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with
286
+
the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
287
+
functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
"""Query a vector db. We support chromadb compatible APIs, it's not required if you prepared your own vector db
331
+
and query function.
332
+
333
+
Args:
334
+
query_texts (List[str]): the query texts.
335
+
n_results (Optional, int): the number of results to return. Default is 10.
336
+
client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used.
337
+
db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
338
+
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
339
+
search_string (Optional, str): the search string. Default is "".
340
+
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
341
+
embedding_function is not None.
342
+
embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with
343
+
the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
344
+
functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
345
+
346
+
Returns:
347
+
QueryResult: the query result. The format is:
348
+
class QueryResult(TypedDict):
349
+
ids: List[IDs]
350
+
embeddings: Optional[List[List[Embedding]]]
351
+
documents: Optional[List[List[Document]]]
352
+
metadatas: Optional[List[List[Metadata]]]
353
+
distances: Optional[List[List[float]]]
354
+
"""
305
355
ifclientisNone:
306
356
client=chromadb.PersistentClient(path=db_path)
307
357
# the collection's embedding function is always the default one, but we want to use the one we used to create the
308
358
# collection. So we compute the embeddings ourselves and pass it to the query function.
0 commit comments