Skip to content

Commit fba6aab

Browse files
Neeratyoy and mfeurer
authored
Making some unit tests work (#1000)
* Making some unit tests work * Waiting for dataset to be processed * Minor test collection fix * Template to handle missing tasks * Accounting for more missing tasks: * Fixing some more unit tests * Simplifying check_task_existence * black changes * Minor formatting * Handling task exists check * Testing edited check task func * Flake fix * More retries on connection error * Adding max_retries to config default * Update database retry unit test * Print to debug hash exception * Fixing checksum unit test * Retry on _download_text_file * Update datasets_tutorial.py * Update custom_flow_tutorial.py * Update test_study_functions.py * Update test_dataset_functions.py * more retries, but also more time between retries * allow for even more retries on get calls * Catching failed get task * undo stupid change * fix one more test * Refactoring md5 hash check inside _send_request * Fixing a fairly common unit test fail * Reverting loose check on unit test Co-authored-by: Matthias Feurer <[email protected]>
1 parent 16799ad commit fba6aab

File tree

18 files changed

+394
-146
lines changed

18 files changed

+394
-146
lines changed

examples/30_extended/custom_flow_tutorial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@
8282
# This allows people to specify auto-sklearn hyperparameters used in this flow.
8383
# In general, using a subflow is not required.
8484
#
85-
# Note: flow 15275 is not actually the right flow on the test server,
85+
# Note: flow 9313 is not actually the right flow on the test server,
8686
# but that does not matter for this demonstration.
8787

88-
autosklearn_flow = openml.flows.get_flow(15275) # auto-sklearn 0.5.1
88+
autosklearn_flow = openml.flows.get_flow(9313) # auto-sklearn 0.5.1
8989
subflow = dict(components=OrderedDict(automl_tool=autosklearn_flow),)
9090

9191
####################################################################################################
@@ -120,7 +120,7 @@
120120
OrderedDict([("oml:name", "time"), ("oml:value", 120), ("oml:component", flow_id)]),
121121
]
122122

123-
task_id = 1408 # Iris Task
123+
task_id = 1965 # Iris Task
124124
task = openml.tasks.get_task(task_id)
125125
dataset_id = task.get_dataset().dataset_id
126126

examples/30_extended/datasets_tutorial.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@
112112

113113
############################################################################
114114
# Edit a created dataset
115-
# =================================================
115+
# ======================
116116
# This example uses the test server, to avoid editing a dataset on the main server.
117117
openml.config.start_using_configuration_for_example()
118118
############################################################################
@@ -143,18 +143,23 @@
143143
# tasks associated with it. To edit critical fields of a dataset (without tasks) owned by you,
144144
# configure the API key:
145145
# openml.config.apikey = 'FILL_IN_OPENML_API_KEY'
146-
data_id = edit_dataset(564, default_target_attribute="y")
147-
print(f"Edited dataset ID: {data_id}")
148-
146+
# This example here only shows a failure when trying to work on a dataset not owned by you:
147+
try:
148+
data_id = edit_dataset(1, default_target_attribute="shape")
149+
except openml.exceptions.OpenMLServerException as e:
150+
print(e)
149151

150152
############################################################################
151153
# Fork dataset
154+
# ============
152155
# Used to create a copy of the dataset with you as the owner.
153156
# Use this API only if you are unable to edit the critical fields (default_target_attribute,
154157
# ignore_attribute, row_id_attribute) of a dataset through the edit_dataset API.
155158
# After the dataset is forked, you can edit the new version of the dataset using edit_dataset.
156159

157-
data_id = fork_dataset(564)
160+
data_id = fork_dataset(1)
161+
print(data_id)
162+
data_id = edit_dataset(data_id, default_target_attribute="shape")
158163
print(f"Forked dataset ID: {data_id}")
159164

160165
openml.config.stop_using_configuration_for_example()

openml/_api_calls.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import hashlib
55
import logging
66
import requests
7+
import xml
78
import xmltodict
89
from typing import Dict, Optional
910

@@ -105,20 +106,9 @@ def _download_text_file(
105106

106107
logging.info("Starting [%s] request for the URL %s", "get", source)
107108
start = time.time()
108-
response = __read_url(source, request_method="get")
109+
response = __read_url(source, request_method="get", md5_checksum=md5_checksum)
109110
downloaded_file = response.text
110111

111-
if md5_checksum is not None:
112-
md5 = hashlib.md5()
113-
md5.update(downloaded_file.encode("utf-8"))
114-
md5_checksum_download = md5.hexdigest()
115-
if md5_checksum != md5_checksum_download:
116-
raise OpenMLHashException(
117-
"Checksum {} of downloaded file is unequal to the expected checksum {}.".format(
118-
md5_checksum_download, md5_checksum
119-
)
120-
)
121-
122112
if output_path is None:
123113
logging.info(
124114
"%.7fs taken for [%s] request for the URL %s", time.time() - start, "get", source,
@@ -163,22 +153,33 @@ def _read_url_files(url, data=None, file_elements=None):
163153
return response
164154

165155

166-
def __read_url(url, request_method, data=None):
156+
def __read_url(url, request_method, data=None, md5_checksum=None):
167157
data = {} if data is None else data
168158
if config.apikey is not None:
169159
data["api_key"] = config.apikey
160+
return _send_request(
161+
request_method=request_method, url=url, data=data, md5_checksum=md5_checksum
162+
)
163+
170164

171-
return _send_request(request_method=request_method, url=url, data=data)
165+
def __is_checksum_equal(downloaded_file, md5_checksum=None):
166+
if md5_checksum is None:
167+
return True
168+
md5 = hashlib.md5()
169+
md5.update(downloaded_file.encode("utf-8"))
170+
md5_checksum_download = md5.hexdigest()
171+
if md5_checksum == md5_checksum_download:
172+
return True
173+
return False
172174

173175

174-
def _send_request(
175-
request_method, url, data, files=None,
176-
):
177-
n_retries = config.connection_n_retries
176+
def _send_request(request_method, url, data, files=None, md5_checksum=None):
177+
n_retries = max(1, min(config.connection_n_retries, config.max_retries))
178+
178179
response = None
179180
with requests.Session() as session:
180181
# Start at one to have a non-zero multiplier for the sleep
181-
for i in range(1, n_retries + 1):
182+
for retry_counter in range(1, n_retries + 1):
182183
try:
183184
if request_method == "get":
184185
response = session.get(url, params=data)
@@ -189,25 +190,36 @@ def _send_request(
189190
else:
190191
raise NotImplementedError()
191192
__check_response(response=response, url=url, file_elements=files)
193+
if request_method == "get" and not __is_checksum_equal(response.text, md5_checksum):
194+
raise OpenMLHashException(
195+
"Checksum of downloaded file is unequal to the expected checksum {} "
196+
"when downloading {}.".format(md5_checksum, url)
197+
)
192198
break
193199
except (
194200
requests.exceptions.ConnectionError,
195201
requests.exceptions.SSLError,
196202
OpenMLServerException,
203+
xml.parsers.expat.ExpatError,
204+
OpenMLHashException,
197205
) as e:
198206
if isinstance(e, OpenMLServerException):
199-
if e.code != 107:
200-
# 107 is a database connection error - only then do retries
207+
if e.code not in [107, 500]:
208+
# 107: database connection error
209+
# 500: internal server error
201210
raise
202-
else:
203-
wait_time = 0.3
204-
else:
205-
wait_time = 0.1
206-
if i == n_retries:
207-
raise e
211+
elif isinstance(e, xml.parsers.expat.ExpatError):
212+
if request_method != "get" or retry_counter >= n_retries:
213+
raise OpenMLServerError(
214+
"Unexpected server error when calling {}. Please contact the "
215+
"developers!\nStatus code: {}\n{}".format(
216+
url, response.status_code, response.text,
217+
)
218+
)
219+
if retry_counter >= n_retries:
220+
raise
208221
else:
209-
time.sleep(wait_time * i)
210-
continue
222+
time.sleep(retry_counter)
211223
if response is None:
212224
raise ValueError("This should never happen!")
213225
return response
@@ -230,6 +242,8 @@ def __parse_server_exception(
230242
raise OpenMLServerError("URI too long! ({})".format(url))
231243
try:
232244
server_exception = xmltodict.parse(response.text)
245+
except xml.parsers.expat.ExpatError:
246+
raise
233247
except Exception:
234248
# OpenML has a sophisticated error system
235249
# where information about failures is provided. try to parse this

openml/config.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def set_file_log_level(file_output_level: int):
8787
"server": "https://www.openml.org/api/v1/xml",
8888
"cachedir": os.path.expanduser(os.path.join("~", ".openml", "cache")),
8989
"avoid_duplicate_runs": "True",
90-
"connection_n_retries": 2,
90+
"connection_n_retries": 10,
91+
"max_retries": 20,
9192
}
9293

9394
config_file = os.path.expanduser(os.path.join("~", ".openml", "config"))
@@ -116,6 +117,7 @@ def get_server_base_url() -> str:
116117

117118
# Number of retries if the connection breaks
118119
connection_n_retries = _defaults["connection_n_retries"]
120+
max_retries = _defaults["max_retries"]
119121

120122

121123
class ConfigurationForExamples:
@@ -183,6 +185,7 @@ def _setup():
183185
global cache_directory
184186
global avoid_duplicate_runs
185187
global connection_n_retries
188+
global max_retries
186189

187190
# read config file, create cache directory
188191
try:
@@ -207,10 +210,11 @@ def _setup():
207210

208211
avoid_duplicate_runs = config.getboolean("FAKE_SECTION", "avoid_duplicate_runs")
209212
connection_n_retries = config.get("FAKE_SECTION", "connection_n_retries")
210-
if connection_n_retries > 20:
213+
max_retries = config.get("FAKE_SECTION", "max_retries")
214+
if connection_n_retries > max_retries:
211215
raise ValueError(
212-
"A higher number of retries than 20 is not allowed to keep the "
213-
"server load reasonable"
216+
"A higher number of retries than {} is not allowed to keep the "
217+
"server load reasonable".format(max_retries)
214218
)
215219

216220

openml/testing.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import shutil
77
import sys
88
import time
9-
from typing import Dict
9+
from typing import Dict, Union, cast
1010
import unittest
1111
import warnings
12+
import pandas as pd
1213

1314
# Currently, importing oslo raises a lot of warning that it will stop working
1415
# under python3.8; remove this once they disappear
@@ -18,6 +19,7 @@
1819

1920
import openml
2021
from openml.tasks import TaskType
22+
from openml.exceptions import OpenMLServerException
2123

2224
import logging
2325

@@ -252,6 +254,55 @@ def _check_fold_timing_evaluations(
252254
self.assertLessEqual(evaluation, max_val)
253255

254256

257+
def check_task_existence(
258+
task_type: TaskType, dataset_id: int, target_name: str, **kwargs
259+
) -> Union[int, None]:
260+
"""Checks if any task with exists on test server that matches the meta data.
261+
262+
Parameter
263+
---------
264+
task_type : openml.tasks.TaskType
265+
dataset_id : int
266+
target_name : str
267+
268+
Return
269+
------
270+
int, None
271+
"""
272+
return_val = None
273+
tasks = openml.tasks.list_tasks(task_type=task_type, output_format="dataframe")
274+
if len(tasks) == 0:
275+
return None
276+
tasks = cast(pd.DataFrame, tasks).loc[tasks["did"] == dataset_id]
277+
if len(tasks) == 0:
278+
return None
279+
tasks = tasks.loc[tasks["target_feature"] == target_name]
280+
if len(tasks) == 0:
281+
return None
282+
task_match = []
283+
for task_id in tasks["tid"].to_list():
284+
task_match.append(task_id)
285+
try:
286+
task = openml.tasks.get_task(task_id)
287+
except OpenMLServerException:
288+
# can fail if task_id was deleted by another unit test running in parallel
289+
task_match.pop(-1)
290+
return_val = None
291+
continue
292+
for k, v in kwargs.items():
293+
if getattr(task, k) != v:
294+
# even if one of the meta-data key mismatches, then task_id is not a match
295+
task_match.pop(-1)
296+
break
297+
# if task_id is retained in the task_match list, it passed all meta key-value matches
298+
if len(task_match) == 1:
299+
return_val = task_id
300+
break
301+
if len(task_match) == 0:
302+
return_val = None
303+
return return_val
304+
305+
255306
try:
256307
from sklearn.impute import SimpleImputer
257308
except ImportError:
@@ -275,4 +326,4 @@ def cat(X):
275326
return X.dtypes == "category"
276327

277328

278-
__all__ = ["TestBase", "SimpleImputer", "CustomImputer", "cat", "cont"]
329+
__all__ = ["TestBase", "SimpleImputer", "CustomImputer", "cat", "cont", "check_task_existence"]

openml/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from functools import wraps
1010
import collections
1111

12+
import openml
1213
import openml._api_calls
1314
import openml.exceptions
1415
from . import config

tests/test_datasets/test_dataset_functions.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
DATASETS_CACHE_DIR_NAME,
3737
)
3838
from openml.datasets import fork_dataset, edit_dataset
39+
from openml.tasks import TaskType, create_task
3940

4041

4142
class TestOpenMLDataset(TestBase):
@@ -414,9 +415,8 @@ def test__getarff_md5_issue(self):
414415
}
415416
self.assertRaisesRegex(
416417
OpenMLHashException,
417-
"Checksum ad484452702105cbf3d30f8deaba39a9 of downloaded file "
418-
"is unequal to the expected checksum abc. "
419-
"Raised when downloading dataset 5.",
418+
"Checksum of downloaded file is unequal to the expected checksum abc when downloading "
419+
"https://www.openml.org/data/download/61. Raised when downloading dataset 5.",
420420
_get_dataset_arff,
421421
description,
422422
)
@@ -498,6 +498,7 @@ def test_upload_dataset_with_url(self):
498498
)
499499
self.assertIsInstance(dataset.dataset_id, int)
500500

501+
@pytest.mark.flaky()
501502
def test_data_status(self):
502503
dataset = OpenMLDataset(
503504
"%s-UploadTestWithURL" % self._get_sentinel(),
@@ -1350,7 +1351,7 @@ def test_data_edit_errors(self):
13501351
"original_data_url, default_target_attribute, row_id_attribute, "
13511352
"ignore_attribute or paper_url to edit.",
13521353
edit_dataset,
1353-
data_id=564,
1354+
data_id=64, # blood-transfusion-service-center
13541355
)
13551356
# Check server exception when unknown dataset is provided
13561357
self.assertRaisesRegex(
@@ -1360,15 +1361,32 @@ def test_data_edit_errors(self):
13601361
data_id=999999,
13611362
description="xor operation dataset",
13621363
)
1364+
1365+
# Need to own a dataset to be able to edit meta-data
1366+
# Will be creating a forked version of an existing dataset to allow the unit test user
1367+
# to edit meta-data of a dataset
1368+
did = fork_dataset(1)
1369+
self._wait_for_dataset_being_processed(did)
1370+
TestBase._mark_entity_for_removal("data", did)
1371+
# Need to upload a task attached to this data to test edit failure
1372+
task = create_task(
1373+
task_type=TaskType.SUPERVISED_CLASSIFICATION,
1374+
dataset_id=did,
1375+
target_name="class",
1376+
estimation_procedure_id=1,
1377+
)
1378+
task = task.publish()
1379+
TestBase._mark_entity_for_removal("task", task.task_id)
13631380
# Check server exception when owner/admin edits critical fields of dataset with tasks
13641381
self.assertRaisesRegex(
13651382
OpenMLServerException,
13661383
"Critical features default_target_attribute, row_id_attribute and ignore_attribute "
13671384
"can only be edited for datasets without any tasks.",
13681385
edit_dataset,
1369-
data_id=223,
1386+
data_id=did,
13701387
default_target_attribute="y",
13711388
)
1389+
13721390
# Check server exception when a non-owner or non-admin tries to edit critical fields
13731391
self.assertRaisesRegex(
13741392
OpenMLServerException,

0 commit comments

Comments
 (0)