
Commit f72e4b6

[Serving] Allow LLModel + LLMPromptArtifact without ModelArtifact instance (#9090)
1 parent: 766816a

3 files changed (+109, -10 lines)

mlrun/errors.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -269,7 +269,9 @@ def __init__(self, models_errors: dict[str:str], *args) -> None:
         super().__init__(self.__repr__(), *args)
 
     def __repr__(self):
-        return f"ModelRunnerError: {repr(self.models_errors)}"
+        return "ModelRunnerError: " + ";\n".join(
+            f"{model} {msg}" for model, msg in self.models_errors.items()
+        )
 
     def __copy__(self):
         return type(self)(models_errors=self.models_errors)
```
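
The joined format reads one model per line instead of a raw dict repr. A minimal sketch of the new output, assuming `ModelRunnerError` is constructed with a `models_errors` dict as the `__copy__` method above suggests:

```python
import mlrun.errors

# Hypothetical per-model error messages.
err = mlrun.errors.ModelRunnerError(
    models_errors={
        "model-a": "is running with naive execution_mechanism but predict() is not implemented",
        "model-b": "failed to load artifact",
    }
)
print(repr(err))
# ModelRunnerError: model-a is running with naive execution_mechanism but predict() is not implemented;
# model-b failed to load artifact
```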

mlrun/serving/states.py

Lines changed: 29 additions & 5 deletions
```diff
@@ -1184,14 +1184,18 @@ def load(self) -> None:
         if self._execution_mechanism == storey.ParallelExecutionMechanisms.asyncio:
             if self.__class__.predict_async is Model.predict_async:
                 raise mlrun.errors.ModelRunnerError(
-                    f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict_async() "
-                    f"is not implemented"
+                    {
+                        self.name: f"is running with {self._execution_mechanism} "
+                        f"execution_mechanism but predict_async() is not implemented"
+                    }
                 )
         else:
             if self.__class__.predict is Model.predict:
                 raise mlrun.errors.ModelRunnerError(
-                    f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict() "
-                    f"is not implemented"
+                    {
+                        self.name: f"is running with {self._execution_mechanism} execution_mechanism but predict() "
+                        f"is not implemented"
+                    }
                 )
 
     def _load_artifacts(self) -> None:
```
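
The payload passed to `ModelRunnerError` is now a `{model_name: message}` dict, matching the per-model formatting added in `mlrun/errors.py` above. The check itself is unchanged: a model must override `predict()` (or `predict_async()` under the asyncio mechanism). A hypothetical subclass that passes it, assuming `Model` is importable from `mlrun.serving.states`:

```python
from mlrun.serving.states import Model  # assumed import path

class EchoModel(Model):
    def predict(self, body, **kwargs):
        # Overriding predict() means self.__class__.predict is no longer
        # Model.predict, so load() does not raise ModelRunnerError here.
        return body
```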
```diff
@@ -1210,7 +1214,9 @@ def _get_artifact_object(
         uri = proxy_uri or self.artifact_uri
         if uri:
             if mlrun.datastore.is_store_uri(uri):
-                artifact, _ = mlrun.store_manager.get_store_artifact(uri)
+                artifact, _ = mlrun.store_manager.get_store_artifact(
+                    uri, allow_empty_resources=True
+                )
                 return artifact
             else:
                 raise ValueError(
```
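
Passing `allow_empty_resources=True` lets the store lookup return an artifact that has no backing model file, which is exactly the `LLMPromptArtifact`-without-`ModelArtifact` case this commit enables. A hedged sketch of the call, with a hypothetical store URI:

```python
import mlrun

# Hypothetical store URI pointing at a prompt artifact with no model file.
uri = "store://artifacts/my-project/my_llm:latest"
artifact, _ = mlrun.store_manager.get_store_artifact(
    uri,
    allow_empty_resources=True,  # tolerate an artifact without resources
)
```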
```diff
@@ -1419,6 +1425,24 @@ async def predict_async(
         )
         return body
 
+    def init(self):
+        super().init()
+
+        if not self.model_provider:
+            if self._execution_mechanism != storey.ParallelExecutionMechanisms.asyncio:
+                unchanged_predict = self.__class__.predict is LLModel.predict
+                predict_function_name = "predict"
+            else:
+                unchanged_predict = (
+                    self.__class__.predict_async is LLModel.predict_async
+                )
+                predict_function_name = "predict_async"
+            if unchanged_predict:
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"Model provider could not be determined for model '{self.name}',"
+                    f" and the {predict_function_name} function was not overridden."
+                )
+
     def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
         llm_prompt_artifact = self._get_invocation_artifact(origin_name)
         messages, invocation_config = self.enrich_prompt(
```
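
This `init()` override is the core of the change: an `LLModel` whose `model_provider` cannot be determined (prompt artifact only, no `ModelArtifact`) now fails fast at initialization unless the relevant predict function is overridden. A minimal sketch of a subclass that is valid without a provider, assuming `LLModel` is importable from `mlrun.serving.states`:

```python
from mlrun.serving.states import LLModel  # assumed import path

class PromptOnlyModel(LLModel):
    def predict(self, body, **kwargs):
        # predict() is overridden, so init() accepts a missing model_provider;
        # the LLMPromptArtifact alone drives prompt enrichment in run().
        return body
```

Under the asyncio mechanism the check targets `predict_async()` instead, which is why the `DummyAsyncLLMWithoutAsyncPredict` test case below still raises.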

tests/serving/test_async_flow.py

Lines changed: 77 additions & 4 deletions
```diff
@@ -263,6 +263,21 @@ def predict(self, body, **kwargs):
         return body
 
 
+class DummyLLM(LLModel):
+    def predict(self, body: typing.Any, **kwargs):
+        return body
+
+
+class DummyAsyncLLM(LLModel):
+    async def predict_async(self, body: typing.Any, **kwargs):
+        return body
+
+
+class DummyAsyncLLMWithoutAsyncPredict(LLModel):
+    def predict(self, body: typing.Any, **kwargs):
+        return body
+
+
 class MyPklModel(Model):
     def __init__(self, name, raise_exception, artifact_uri, **kwargs):
         super().__init__(
```
```diff
@@ -1167,8 +1182,8 @@ def test_using_model_without_predict_implementation(execution_mechanism: str):
         function.to_mock_server()
     method_name = "predict()" if execution_mechanism != "asyncio" else "predict_async()"
     expected_msg = (
-        f"'model_without_predict is running with {execution_mechanism} execution_mechanism but "
-        f"{method_name} is not implemented'"
+        f"model_without_predict is running with {execution_mechanism} execution_mechanism but "
+        f"{method_name} is not implemented"
    )
    assert expected_msg in str(exc_info.value)
 
```

The surrounding single quotes are dropped from the expected message because the error text is no longer built from `repr()` of the errors dict (see the `__repr__` change in `mlrun/errors.py` above).

```diff
@@ -1208,8 +1223,8 @@ def test_shared_using_model_without_predict_implementation(execution_mechanism:
         "predict()" if execution_mechanism != "asyncio" else "predict_async()"
     )
     expected_msg = (
-        f"'model_without_predict_shared is running with {execution_mechanism} execution_mechanism but "
-        f"{method_name} is not implemented'"
+        f"model_without_predict_shared is running with {execution_mechanism} execution_mechanism but "
+        f"{method_name} is not implemented"
     )
     assert expected_msg in str(exc_info.value)
 
```

```diff
@@ -1295,3 +1310,61 @@ def test_configure_model_runner_step_max_threads_processes(concurrency: str):
     ), "Max threads not configured properly"
     server.test(body={"n": 1})
     server.wait_for_completion()
+
+
+@pytest.mark.parametrize(
+    "model_class, raise_exception",
+    [
+        (
+            "LLModel",
+            True,
+        ),  # LLModel should raise error because predict was not overridden
+        # DummyAsyncLLMWithoutAsyncPredict should raise error because async_predict was not overridden:
+        ("DummyAsyncLLMWithoutAsyncPredict", True),
+        ("DummyLLM", False),
+        ("DummyAsyncLLM", False),
+    ],
+)
+def test_llmodel_without_model_artifact(model_class, raise_exception):
+    is_async = model_class in ("DummyAsyncLLM", "DummyAsyncLLMWithoutAsyncPredict")
+    execution_mechanism = "asyncio" if is_async else "naive"
+    predict_function_name = "predict_async" if is_async else "predict"
+    function = mlrun.new_function("tests", kind="serving")
+    graph = function.set_topology("flow", engine="async")
+    model_runner_step = ModelRunnerStep(name="model-runner")
+    project = mlrun.new_project("llmodel-without-model-artifact", save=False)
+    llm_artifact = project.log_llm_prompt(
+        "my_llm",
+        prompt_template=[
+            {"role": "user", "content": "What is the capital city of {country}?"}
+        ],
+        prompt_legend={"country": {"field": None, "description": "Great"}},
+    )
+
+    model_runner_step.add_model(
+        model_class=model_class,
+        execution_mechanism=execution_mechanism,
+        endpoint_name="my-model",
+        model_artifact=llm_artifact,
+    )
+    graph.to(model_runner_step).respond()
+    server = None
+    with unittest.mock.patch(
+        "mlrun.datastore.datastore.get_store_resource",
+        return_value=llm_artifact,
+    ):
+        try:
+            if raise_exception:
+                with pytest.raises(
+                    mlrun.errors.MLRunRuntimeError,
+                    match=f"Model provider could not be determined for model 'my-model', and the"
+                    f" {predict_function_name} function was not overridden",
+                ):
+                    server = function.to_mock_server()
+            else:
+                server = function.to_mock_server()
+                resp = server.test(body={"country": "france"})
+                assert resp == {"country": "france"}
+        finally:
+            if server:
+                server.wait_for_completion()
```
