@@ -263,6 +263,21 @@ def predict(self, body, **kwargs):
263263 return body
264264
265265
class DummyLLM(LLModel):
    """LLModel stub with a synchronous predict that echoes the request body."""

    def predict(self, body: typing.Any, **kwargs):
        # Identity behavior: return the incoming body untouched.
        return body
269+
270+
class DummyAsyncLLM(LLModel):
    """LLModel stub exposing only the asyncio path; predict_async echoes the body."""

    async def predict_async(self, body: typing.Any, **kwargs):
        # Identity behavior on the async execution path.
        return body
274+
275+
class DummyAsyncLLMWithoutAsyncPredict(LLModel):
    """LLModel stub that overrides only the sync predict (no predict_async).

    Used to verify that running with the asyncio mechanism fails when
    predict_async was not overridden.
    """

    def predict(self, body: typing.Any, **kwargs):
        # Only the synchronous entry point is provided; it echoes the body.
        return body
279+
280+
266281class MyPklModel (Model ):
267282 def __init__ (self , name , raise_exception , artifact_uri , ** kwargs ):
268283 super ().__init__ (
@@ -1167,8 +1182,8 @@ def test_using_model_without_predict_implementation(execution_mechanism: str):
11671182 function .to_mock_server ()
11681183 method_name = "predict()" if execution_mechanism != "asyncio" else "predict_async()"
11691184 expected_msg = (
1170- f"' model_without_predict is running with { execution_mechanism } execution_mechanism but "
1171- f"{ method_name } is not implemented' "
1185+ f"model_without_predict is running with { execution_mechanism } execution_mechanism but "
1186+ f"{ method_name } is not implemented"
11721187 )
11731188 assert expected_msg in str (exc_info .value )
11741189
@@ -1208,8 +1223,8 @@ def test_shared_using_model_without_predict_implementation(execution_mechanism:
12081223 "predict()" if execution_mechanism != "asyncio" else "predict_async()"
12091224 )
12101225 expected_msg = (
1211- f"' model_without_predict_shared is running with { execution_mechanism } execution_mechanism but "
1212- f"{ method_name } is not implemented' "
1226+ f"model_without_predict_shared is running with { execution_mechanism } execution_mechanism but "
1227+ f"{ method_name } is not implemented"
12131228 )
12141229 assert expected_msg in str (exc_info .value )
12151230
@@ -1295,3 +1310,61 @@ def test_configure_model_runner_step_max_threads_processes(concurrency: str):
12951310 ), "Max threads not configured properly"
12961311 server .test (body = {"n" : 1 })
12971312 server .wait_for_completion ()
1313+
1314+
@pytest.mark.parametrize(
    "model_class, raise_exception",
    [
        # LLModel must fail: neither predict nor predict_async was overridden.
        ("LLModel", True),
        # Async mechanism but predict_async not overridden -> must fail too.
        ("DummyAsyncLLMWithoutAsyncPredict", True),
        ("DummyLLM", False),
        ("DummyAsyncLLM", False),
    ],
)
def test_llmodel_without_model_artifact(model_class, raise_exception):
    """Serve an LLM prompt artifact with no model artifact behind it.

    Classes that override the predict function matching their execution
    mechanism should serve successfully and echo the request body; the
    others should fail at mock-server build time with a clear error.
    """
    runs_async = model_class in ("DummyAsyncLLM", "DummyAsyncLLMWithoutAsyncPredict")
    mechanism = "asyncio" if runs_async else "naive"
    predict_fn = "predict_async" if runs_async else "predict"

    serving_fn = mlrun.new_function("tests", kind="serving")
    topology = serving_fn.set_topology("flow", engine="async")
    runner = ModelRunnerStep(name="model-runner")
    project = mlrun.new_project("llmodel-without-model-artifact", save=False)
    prompt_artifact = project.log_llm_prompt(
        "my_llm",
        prompt_template=[
            {"role": "user", "content": "What is the capital city of {country}?"}
        ],
        prompt_legend={"country": {"field": None, "description": "Great"}},
    )

    runner.add_model(
        model_class=model_class,
        execution_mechanism=mechanism,
        endpoint_name="my-model",
        model_artifact=prompt_artifact,
    )
    topology.to(runner).respond()

    mock_server = None
    # Resolve the store resource to our in-memory artifact so no datastore
    # access is attempted during model initialization.
    with unittest.mock.patch(
        "mlrun.datastore.datastore.get_store_resource",
        return_value=prompt_artifact,
    ):
        try:
            if raise_exception:
                expected_error = (
                    f"Model provider could not be determined for model 'my-model', and the"
                    f" {predict_fn} function was not overridden"
                )
                with pytest.raises(
                    mlrun.errors.MLRunRuntimeError, match=expected_error
                ):
                    mock_server = serving_fn.to_mock_server()
            else:
                mock_server = serving_fn.to_mock_server()
                response = mock_server.test(body={"country": "france"})
                # The dummy models echo the request body unchanged.
                assert response == {"country": "france"}
        finally:
            # Only a successfully built server needs draining/teardown.
            if mock_server:
                mock_server.wait_for_completion()
0 commit comments