Skip to content

Commit 861b52d

Browse files
committed
Speed up dataset unit tests
by only loading necessary datasets
1 parent 72576bd commit 861b52d

File tree

1 file changed

+38
-35
lines changed

1 file changed

+38
-35
lines changed

tests/test_datasets/test_dataset.py

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,6 @@ def setUp(self):
2222
super(OpenMLDatasetTest, self).setUp()
2323
openml.config.server = self.production_server
2424

25-
# Load dataset id 2 - dataset 2 is interesting because it contains
26-
# missing values, categorical features etc.
27-
self.dataset = openml.datasets.get_dataset(2, download_data=False)
28-
# titanic as missing values, categories, and string
29-
self.titanic = openml.datasets.get_dataset(40945, download_data=False)
30-
# these datasets have some boolean features
31-
self.pc4 = openml.datasets.get_dataset(1049, download_data=False)
32-
self.jm1 = openml.datasets.get_dataset(1053, download_data=False)
33-
self.iris = openml.datasets.get_dataset(61, download_data=False)
34-
3525
def test_repr(self):
3626
# create a bare-bones dataset as would be returned by
3727
# create_dataset
@@ -63,7 +53,8 @@ def test__unpack_categories_with_nan_likes(self):
6353

6454
def test_get_data_array(self):
6555
# Basic usage
66-
rval, _, categorical, attribute_names = self.dataset.get_data(dataset_format="array")
56+
dataset = openml.datasets.get_dataset(2, download_data=False)
57+
rval, _, categorical, attribute_names = dataset.get_data(dataset_format="array")
6758
self.assertIsInstance(rval, np.ndarray)
6859
self.assertEqual(rval.dtype, np.float32)
6960
self.assertEqual((898, 39), rval.shape)
@@ -76,12 +67,14 @@ def test_get_data_array(self):
7667
# check that an error is raised when the dataset contains string
7768
err_msg = "PyOpenML cannot handle string when returning numpy arrays"
7869
with pytest.raises(PyOpenMLError, match=err_msg):
79-
self.titanic.get_data(dataset_format="array")
70+
titanic = openml.datasets.get_dataset(40945, download_data=False)
71+
titanic.get_data(dataset_format="array")
8072

8173
def test_get_data_pandas(self):
82-
data, _, _, _ = self.titanic.get_data(dataset_format="dataframe")
74+
titanic = openml.datasets.get_dataset(40945, download_data=False)
75+
data, _, _, _ = titanic.get_data(dataset_format="dataframe")
8376
self.assertTrue(isinstance(data, pd.DataFrame))
84-
self.assertEqual(data.shape[1], len(self.titanic.features))
77+
self.assertEqual(data.shape[1], len(titanic.features))
8578
self.assertEqual(data.shape[0], 1309)
8679
col_dtype = {
8780
"pclass": "uint8",
@@ -102,8 +95,8 @@ def test_get_data_pandas(self):
10295
for col_name in data.columns:
10396
self.assertTrue(data[col_name].dtype.name == col_dtype[col_name])
10497

105-
X, y, _, _ = self.titanic.get_data(
106-
dataset_format="dataframe", target=self.titanic.default_target_attribute
98+
X, y, _, _ = titanic.get_data(
99+
dataset_format="dataframe", target=titanic.default_target_attribute
107100
)
108101
self.assertTrue(isinstance(X, pd.DataFrame))
109102
self.assertTrue(isinstance(y, pd.Series))
@@ -116,19 +109,22 @@ def test_get_data_pandas(self):
116109
def test_get_data_boolean_pandas(self):
117110
# test to check that we are converting properly True and False even
118111
# with some inconsistency when dumping the data on openml
119-
data, _, _, _ = self.jm1.get_data()
112+
jm1 = openml.datasets.get_dataset(1053, download_data=False)
113+
data, _, _, _ = jm1.get_data()
120114
self.assertTrue(data["defects"].dtype.name == "category")
121115
self.assertTrue(set(data["defects"].cat.categories) == {True, False})
122116

123-
data, _, _, _ = self.pc4.get_data()
117+
pc4 = openml.datasets.get_dataset(1049, download_data=False)
118+
data, _, _, _ = pc4.get_data()
124119
self.assertTrue(data["c"].dtype.name == "category")
125120
self.assertTrue(set(data["c"].cat.categories) == {True, False})
126121

127122
def test_get_data_no_str_data_for_nparrays(self):
128123
# check that an error is raised when the dataset contains string
129124
err_msg = "PyOpenML cannot handle string when returning numpy arrays"
130125
with pytest.raises(PyOpenMLError, match=err_msg):
131-
self.titanic.get_data(dataset_format="array")
126+
titanic = openml.datasets.get_dataset(40945, download_data=False)
127+
titanic.get_data(dataset_format="array")
132128

133129
def _check_expected_type(self, dtype, is_cat, col):
134130
if is_cat:
@@ -141,23 +137,25 @@ def _check_expected_type(self, dtype, is_cat, col):
141137
self.assertEqual(dtype.name, expected_type)
142138

143139
def test_get_data_with_rowid(self):
144-
self.dataset.row_id_attribute = "condition"
145-
rval, _, categorical, _ = self.dataset.get_data(include_row_id=True)
140+
dataset = openml.datasets.get_dataset(2, download_data=False)
141+
dataset.row_id_attribute = "condition"
142+
rval, _, categorical, _ = dataset.get_data(include_row_id=True)
146143
self.assertIsInstance(rval, pd.DataFrame)
147144
for (dtype, is_cat, col) in zip(rval.dtypes, categorical, rval):
148145
self._check_expected_type(dtype, is_cat, rval[col])
149146
self.assertEqual(rval.shape, (898, 39))
150147
self.assertEqual(len(categorical), 39)
151148

152-
rval, _, categorical, _ = self.dataset.get_data()
149+
rval, _, categorical, _ = dataset.get_data()
153150
self.assertIsInstance(rval, pd.DataFrame)
154151
for (dtype, is_cat, col) in zip(rval.dtypes, categorical, rval):
155152
self._check_expected_type(dtype, is_cat, rval[col])
156153
self.assertEqual(rval.shape, (898, 38))
157154
self.assertEqual(len(categorical), 38)
158155

159156
def test_get_data_with_target_array(self):
160-
X, y, _, attribute_names = self.dataset.get_data(dataset_format="array", target="class")
157+
dataset = openml.datasets.get_dataset(2, download_data=False)
158+
X, y, _, attribute_names = dataset.get_data(dataset_format="array", target="class")
161159
self.assertIsInstance(X, np.ndarray)
162160
self.assertEqual(X.dtype, np.float32)
163161
self.assertEqual(X.shape, (898, 38))
@@ -167,7 +165,8 @@ def test_get_data_with_target_array(self):
167165
self.assertNotIn("class", attribute_names)
168166

169167
def test_get_data_with_target_pandas(self):
170-
X, y, categorical, attribute_names = self.dataset.get_data(target="class")
168+
dataset = openml.datasets.get_dataset(2, download_data=False)
169+
X, y, categorical, attribute_names = dataset.get_data(target="class")
171170
self.assertIsInstance(X, pd.DataFrame)
172171
for (dtype, is_cat, col) in zip(X.dtypes, categorical, X):
173172
self._check_expected_type(dtype, is_cat, X[col])
@@ -181,50 +180,54 @@ def test_get_data_with_target_pandas(self):
181180
self.assertNotIn("class", attribute_names)
182181

183182
def test_get_data_rowid_and_ignore_and_target(self):
184-
self.dataset.ignore_attribute = ["condition"]
185-
self.dataset.row_id_attribute = ["hardness"]
186-
X, y, categorical, names = self.dataset.get_data(target="class")
183+
dataset = openml.datasets.get_dataset(2, download_data=False)
184+
dataset.ignore_attribute = ["condition"]
185+
dataset.row_id_attribute = ["hardness"]
186+
X, y, categorical, names = dataset.get_data(target="class")
187187
self.assertEqual(X.shape, (898, 36))
188188
self.assertEqual(len(categorical), 36)
189189
cats = [True] * 3 + [False, True, True, False] + [True] * 23 + [False] * 3 + [True] * 3
190190
self.assertListEqual(categorical, cats)
191191
self.assertEqual(y.shape, (898,))
192192

193193
def test_get_data_with_ignore_attributes(self):
194-
self.dataset.ignore_attribute = ["condition"]
195-
rval, _, categorical, _ = self.dataset.get_data(include_ignore_attribute=True)
194+
dataset = openml.datasets.get_dataset(2, download_data=False)
195+
dataset.ignore_attribute = ["condition"]
196+
rval, _, categorical, _ = dataset.get_data(include_ignore_attribute=True)
196197
for (dtype, is_cat, col) in zip(rval.dtypes, categorical, rval):
197198
self._check_expected_type(dtype, is_cat, rval[col])
198199
self.assertEqual(rval.shape, (898, 39))
199200
self.assertEqual(len(categorical), 39)
200201

201-
rval, _, categorical, _ = self.dataset.get_data(include_ignore_attribute=False)
202+
rval, _, categorical, _ = dataset.get_data(include_ignore_attribute=False)
202203
for (dtype, is_cat, col) in zip(rval.dtypes, categorical, rval):
203204
self._check_expected_type(dtype, is_cat, rval[col])
204205
self.assertEqual(rval.shape, (898, 38))
205206
self.assertEqual(len(categorical), 38)
206207

207208
def test_get_data_with_nonexisting_class(self):
209+
dataset = openml.datasets.get_dataset(2, download_data=False)
208210
# This class is using the anneal dataset with labels [1, 2, 3, 4, 5, 'U']. However,
209211
# label 4 does not exist and we test that the features 5 and 'U' are correctly mapped to
210212
# indices 4 and 5, and that nothing is mapped to index 3.
211-
_, y, _, _ = self.dataset.get_data("class", dataset_format="dataframe")
213+
_, y, _, _ = dataset.get_data("class", dataset_format="dataframe")
212214
self.assertEqual(list(y.dtype.categories), ["1", "2", "3", "4", "5", "U"])
213-
_, y, _, _ = self.dataset.get_data("class", dataset_format="array")
215+
_, y, _, _ = dataset.get_data("class", dataset_format="array")
214216
self.assertEqual(np.min(y), 0)
215217
self.assertEqual(np.max(y), 5)
216218
# Check that no label is mapped to 3, since it is reserved for label '4'.
217219
self.assertEqual(np.sum(y == 3), 0)
218220

219221
def test_get_data_corrupt_pickle(self):
220222
# Lazy loaded dataset, populate cache.
221-
self.iris.get_data()
223+
iris = openml.datasets.get_dataset(61, download_data=False)
224+
iris.get_data()
222225
# Corrupt pickle file, overwrite as empty.
223-
with open(self.iris.data_pickle_file, "w") as fh:
226+
with open(iris.data_pickle_file, "w") as fh:
224227
fh.write("")
225228
# Despite the corrupt file, the data should be loaded from the ARFF file.
226229
# A warning message is written to the python logger.
227-
xy, _, _, _ = self.iris.get_data()
230+
xy, _, _, _ = iris.get_data()
228231
self.assertIsInstance(xy, pd.DataFrame)
229232
self.assertEqual(xy.shape, (150, 5))
230233

0 commit comments

Comments
 (0)