
Commit 3ea4739

Replace unittests in providers-apache tests by pure pytest (apache#27948)
1 parent 70a9980 commit 3ea4739


42 files changed: +540 additions, -685 deletions

tests/providers/apache/beam/hooks/test_beam.py

Lines changed: 30 additions & 30 deletions
@@ -20,12 +20,10 @@
 import os
 import re
 import subprocess
-import unittest
 from unittest import mock
 from unittest.mock import MagicMock
 
 import pytest
-from parameterized import parameterized
 
 from airflow.exceptions import AirflowException
 from airflow.providers.apache.beam.hooks.beam import BeamCommandRunner, BeamHook, beam_options_to_args
@@ -58,7 +56,7 @@
 """
 
 
-class TestBeamHook(unittest.TestCase):
+class TestBeamHook:
     @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
     @mock.patch("airflow.providers.apache.beam.hooks.beam.subprocess.check_output", return_value=b"2.39.0")
     def test_start_python_pipeline(self, mock_check_output, mock_runner):
@@ -106,18 +104,19 @@ def test_start_python_pipeline_unsupported_option(self, mock_check_output):
                 process_line_callback=MagicMock(),
             )
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "py_interpreter",
         [
-            ("default_to_python3", "python3"),
-            ("major_version_2", "python2"),
-            ("major_version_3", "python3"),
-            ("minor_version", "python3.6"),
-        ]
+            pytest.param("python", id="default python"),
+            pytest.param("python2", id="major python version 2.x"),
+            pytest.param("python3", id="major python version 3.x"),
+            pytest.param("python3.6", id="major.minor python version"),
+        ],
     )
     @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
     @mock.patch("airflow.providers.apache.beam.hooks.beam.subprocess.check_output", return_value=b"2.39.0")
     def test_start_python_pipeline_with_custom_interpreter(
-        self, _, py_interpreter, mock_check_output, mock_runner
+        self, mock_check_output, mock_runner, py_interpreter
     ):
         hook = BeamHook(runner=DEFAULT_RUNNER)
         wait_for_done = mock_runner.return_value.wait_for_done
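An aside on the signature change in this hunk (illustrative code, not part of the commit): stacked @mock.patch decorators inject their mocks positionally, bottom-up, immediately after self, while pytest supplies @pytest.mark.parametrize values by name. The parametrized name therefore moves to the end of the signature, and parameterized's throwaway `_` name argument disappears. A minimal sketch, assuming only pytest and the standard library:

from unittest import mock

import pytest


class TestPatchOrdering:
    @pytest.mark.parametrize("py_interpreter", ["python3", "python3.6"])
    @mock.patch("os.path.exists")  # top patch -> second mock argument
    @mock.patch("os.getcwd")       # bottom patch -> first mock argument
    def test_order(self, mock_getcwd, mock_exists, py_interpreter):
        # mocks arrive positionally after self; the parametrized value
        # is matched by name, so it can sit last in the signature
        assert isinstance(py_interpreter, str)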
@@ -144,23 +143,24 @@ def test_start_python_pipeline_with_custom_interpreter(
         )
         wait_for_done.assert_called_once_with()
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "current_py_requirements, current_py_system_site_packages",
         [
-            (["foo-bar"], False),
-            (["foo-bar"], True),
-            ([], True),
-        ]
+            pytest.param("foo-bar", False, id="requirements without system site-packages"),
+            pytest.param("foo-bar", True, id="requirements with system site-packages"),
+            pytest.param([], True, id="only system site-packages"),
+        ],
     )
     @mock.patch(BEAM_STRING.format("prepare_virtualenv"))
     @mock.patch(BEAM_STRING.format("BeamCommandRunner"))
     @mock.patch("airflow.providers.apache.beam.hooks.beam.subprocess.check_output", return_value=b"2.39.0")
     def test_start_python_pipeline_with_non_empty_py_requirements_and_without_system_packages(
         self,
-        current_py_requirements,
-        current_py_system_site_packages,
         mock_check_output,
         mock_runner,
         mock_virtualenv,
+        current_py_requirements,
+        current_py_system_site_packages,
     ):
         hook = BeamHook(runner=DEFAULT_RUNNER)
         wait_for_done = mock_runner.return_value.wait_for_done
@@ -204,7 +204,7 @@ def test_start_python_pipeline_with_empty_py_requirements_and_without_system_pac
         wait_for_done = mock_runner.return_value.wait_for_done
         process_line_callback = MagicMock()
 
-        with self.assertRaisesRegex(AirflowException, "Invalid method invocation."):
+        with pytest.raises(AirflowException, match=r"Invalid method invocation\."):
            hook.start_python_pipeline(
                variables=copy.deepcopy(BEAM_VARIABLES_PY),
                py_file=PY_FILE,
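Worth noting for this conversion (the sketch below is illustrative, not from the commit): pytest.raises(match=...) applies the pattern with re.search against the string form of the raised exception, hence the escaped dot. unittest's assertRaisesRegex also treated its argument as a regex, so the original unescaped "Invalid method invocation." silently matched any character at the dot. A minimal, self-contained example:

import pytest


def invoke():
    raise ValueError("Invalid method invocation.")


def test_invoke_raises():
    # match is a regex applied with re.search, hence the escaped dot
    with pytest.raises(ValueError, match=r"Invalid method invocation\."):
        invoke()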
@@ -302,20 +302,18 @@ def test_start_go_pipeline_without_go_installed_raises(self, mock_which):
         mock_which.return_value = None
         hook = BeamHook(runner=DEFAULT_RUNNER)
 
-        with self.assertRaises(AirflowException) as ex_ctx:
+        error_message = (
+            r"You need to have Go installed to run beam go pipeline\. See .* "
+            "installation guide. If you are running airflow in Docker see more info at '.*'"
+        )
+        with pytest.raises(AirflowException, match=error_message):
             hook.start_go_pipeline(
                 go_file=GO_FILE,
                 variables=copy.deepcopy(BEAM_VARIABLES_GO),
             )
 
-        assert (
-            "You need to have Go installed to run beam go pipeline. See https://go.dev/doc/install "
-            "installation guide. If you are running airflow in Docker see more info at "
-            "'https://airflow.apache.org/docs/docker-stack/recipes.html'." == str(ex_ctx.exception)
-        )
 
-
-class TestBeamRunner(unittest.TestCase):
+class TestBeamRunner:
     @mock.patch("airflow.providers.apache.beam.hooks.beam.BeamCommandRunner.log")
     @mock.patch("subprocess.Popen")
     @mock.patch("select.select")
@@ -343,18 +341,20 @@ def poll_resp_error():
         mock_popen.assert_called_once_with(
             cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True, cwd=None
         )
-        self.assertRaises(Exception, beam.wait_for_done)
+        with pytest.raises(Exception):
+            beam.wait_for_done()
 
 
-class TestBeamOptionsToArgs(unittest.TestCase):
-    @parameterized.expand(
+class TestBeamOptionsToArgs:
+    @pytest.mark.parametrize(
+        "options, expected_args",
         [
             ({"key": "val"}, ["--key=val"]),
             ({"key": None}, ["--key"]),
             ({"key": True}, ["--key"]),
             ({"key": False}, ["--key=False"]),
             ({"key": ["a", "b", "c"]}, ["--key=a", "--key=b", "--key=c"]),
-        ]
+        ],
     )
     def test_beam_options_to_args(self, options, expected_args):
         args = beam_options_to_args(options)
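One practical difference between the two decorators (illustrative snippet, not part of the diff): parameterized.expand mangled the first tuple element into the test name, which is why the old cases carried labels like "default_to_python3". pytest derives readable test ids either from pytest.param(..., id=...) or, for plain tuples as in TestBeamOptionsToArgs, from the parameter values themselves. A standalone sketch with a stand-in for beam_options_to_args:

import pytest


@pytest.mark.parametrize(
    "options, expected_args",
    [
        ({"key": "val"}, ["--key=val"]),
        ({"key": True}, ["--key"]),
    ],
)
def test_options_to_args(options, expected_args):
    # stand-in for beam_options_to_args (the real helper lives in
    # airflow.providers.apache.beam.hooks.beam)
    args = [f"--{k}" if v is True else f"--{k}={v}" for k, v in options.items()]
    assert args == expected_args

Run with pytest -v, these cases show up with generated ids such as test_options_to_args[options0-expected_args0], or with the given string when a pytest.param id is supplied.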

tests/providers/apache/beam/operators/test_beam.py

Lines changed: 24 additions & 25 deletions
@@ -16,7 +16,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -51,8 +50,8 @@
 TEST_IMPERSONATION_ACCOUNT = "[email protected]"
 
 
-class TestBeamRunPythonPipelineOperator(unittest.TestCase):
-    def setUp(self):
+class TestBeamRunPythonPipelineOperator:
+    def setup_method(self):
         self.operator = BeamRunPythonPipelineOperator(
             task_id=TASK_ID,
             py_file=PY_FILE,
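The setUp to setup_method rename above is the standard pytest idiom: once the class no longer inherits from unittest.TestCase, pytest's own xunit-style hooks take over, and setup_method(self) runs before every test method of the class (the method argument is optional since pytest 3.0). A minimal sketch with illustrative names:

class TestLifecycle:
    def setup_method(self):
        # fresh state before each test, as setUp gave under unittest
        self.items = []

    def test_append(self):
        self.items.append(1)
        assert self.items == [1]

    def test_starts_empty(self):
        # passes because setup_method re-ran for this test
        assert self.items == []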
@@ -63,13 +62,13 @@ def setUp(self):
 
     def test_init(self):
         """Test BeamRunPythonPipelineOperator instance is properly initialized."""
-        self.assertEqual(self.operator.task_id, TASK_ID)
-        self.assertEqual(self.operator.py_file, PY_FILE)
-        self.assertEqual(self.operator.runner, DEFAULT_RUNNER)
-        self.assertEqual(self.operator.py_options, PY_OPTIONS)
-        self.assertEqual(self.operator.py_interpreter, PY_INTERPRETER)
-        self.assertEqual(self.operator.default_pipeline_options, DEFAULT_OPTIONS_PYTHON)
-        self.assertEqual(self.operator.pipeline_options, EXPECTED_ADDITIONAL_OPTIONS)
+        assert self.operator.task_id == TASK_ID
+        assert self.operator.py_file == PY_FILE
+        assert self.operator.runner == DEFAULT_RUNNER
+        assert self.operator.py_options == PY_OPTIONS
+        assert self.operator.py_interpreter == PY_INTERPRETER
+        assert self.operator.default_pipeline_options == DEFAULT_OPTIONS_PYTHON
+        assert self.operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
 
     @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
     @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
@@ -180,8 +179,8 @@ def test_on_kill_direct_runner(self, _, dataflow_mock, __):
         dataflow_cancel_job.assert_not_called()
 
 
-class TestBeamRunJavaPipelineOperator(unittest.TestCase):
-    def setUp(self):
+class TestBeamRunJavaPipelineOperator:
+    def setup_method(self):
         self.operator = BeamRunJavaPipelineOperator(
             task_id=TASK_ID,
             jar=JAR_FILE,
@@ -192,12 +191,12 @@ def setUp(self):
 
     def test_init(self):
         """Test BeamRunJavaPipelineOperator instance is properly initialized."""
-        self.assertEqual(self.operator.task_id, TASK_ID)
-        self.assertEqual(self.operator.runner, DEFAULT_RUNNER)
-        self.assertEqual(self.operator.default_pipeline_options, DEFAULT_OPTIONS_JAVA)
-        self.assertEqual(self.operator.job_class, JOB_CLASS)
-        self.assertEqual(self.operator.jar, JAR_FILE)
-        self.assertEqual(self.operator.pipeline_options, ADDITIONAL_OPTIONS)
+        assert self.operator.task_id == TASK_ID
+        assert self.operator.runner == DEFAULT_RUNNER
+        assert self.operator.default_pipeline_options == DEFAULT_OPTIONS_JAVA
+        assert self.operator.job_class == JOB_CLASS
+        assert self.operator.jar == JAR_FILE
+        assert self.operator.pipeline_options == ADDITIONAL_OPTIONS
 
     @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
     @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
@@ -299,8 +298,8 @@ def test_on_kill_direct_runner(self, _, dataflow_mock, __):
         dataflow_cancel_job.assert_not_called()
 
 
-class TestBeamRunGoPipelineOperator(unittest.TestCase):
-    def setUp(self):
+class TestBeamRunGoPipelineOperator:
+    def setup_method(self):
         self.operator = BeamRunGoPipelineOperator(
             task_id=TASK_ID,
             go_file=GO_FILE,
@@ -310,11 +309,11 @@ def setUp(self):
 
     def test_init(self):
         """Test BeamRunGoPipelineOperator instance is properly initialized."""
-        self.assertEqual(self.operator.task_id, TASK_ID)
-        self.assertEqual(self.operator.go_file, GO_FILE)
-        self.assertEqual(self.operator.runner, DEFAULT_RUNNER)
-        self.assertEqual(self.operator.default_pipeline_options, DEFAULT_OPTIONS_PYTHON)
-        self.assertEqual(self.operator.pipeline_options, EXPECTED_ADDITIONAL_OPTIONS)
+        assert self.operator.task_id == TASK_ID
+        assert self.operator.go_file == GO_FILE
+        assert self.operator.runner == DEFAULT_RUNNER
+        assert self.operator.default_pipeline_options == DEFAULT_OPTIONS_PYTHON
+        assert self.operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
 
     @mock.patch(
         "tempfile.TemporaryDirectory",

tests/providers/apache/cassandra/hooks/test_cassandra.py

Lines changed: 4 additions & 5 deletions
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from unittest import mock
 
 import pytest
@@ -35,14 +34,14 @@
 
 
 @pytest.mark.integration("cassandra")
-class TestCassandraHook(unittest.TestCase):
-    def setUp(self):
+class TestCassandraHook:
+    def setup_method(self):
         db.merge_conn(
             Connection(
                 conn_id="cassandra_test",
                 conn_type="cassandra",
                 host="host-1,host-2",
-                port="9042",
+                port=9042,
                 schema="test_keyspace",
                 extra='{"load_balancing_policy":"TokenAwarePolicy","protocol_version":4}',
             )
@@ -52,7 +51,7 @@ def setUp(self):
                 conn_id="cassandra_default_with_schema",
                 conn_type="cassandra",
                 host="cassandra",
-                port="9042",
+                port=9042,
                 schema="s",
             )
         )
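The port change in this file is a small type fix riding along with the pytest migration: Connection stores port as an integer column, so the fixtures now pass 9042 as an int rather than relying on coercion of the string "9042". A hedged sketch (the conn_id below is illustrative, not from the commit):

from airflow.models import Connection

conn = Connection(
    conn_id="cassandra_example",
    conn_type="cassandra",
    host="host-1,host-2",
    port=9042,  # int, matching the integer column on the Connection model
    schema="test_keyspace",
)
assert conn.port == 9042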

tests/providers/apache/cassandra/sensors/test_record.py

Lines changed: 1 addition & 2 deletions
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from unittest.mock import patch
 
 from airflow.providers.apache.cassandra.sensors.record import CassandraRecordSensor
@@ -27,7 +26,7 @@
 TEST_CASSANDRA_KEY = {"foo": "bar"}
 
 
-class TestCassandraRecordSensor(unittest.TestCase):
+class TestCassandraRecordSensor:
     @patch("airflow.providers.apache.cassandra.sensors.record.CassandraHook")
     def test_poke(self, mock_hook):
         sensor = CassandraRecordSensor(

tests/providers/apache/cassandra/sensors/test_table.py

Lines changed: 1 addition & 2 deletions
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from unittest.mock import patch
 
 from airflow.providers.apache.cassandra.sensors.table import CassandraTableSensor
@@ -27,7 +26,7 @@
 TEST_CASSANDRA_TABLE_WITH_KEYSPACE = "keyspacename.tablename"
 
 
-class TestCassandraTableSensor(unittest.TestCase):
+class TestCassandraTableSensor:
     @patch("airflow.providers.apache.cassandra.sensors.table.CassandraHook")
     def test_poke(self, mock_hook):
         sensor = CassandraTableSensor(
sensor = CassandraTableSensor(

tests/providers/apache/drill/hooks/test_drill.py

Lines changed: 2 additions & 3 deletions
@@ -17,14 +17,13 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from unittest.mock import MagicMock
 
 from airflow.providers.apache.drill.hooks.drill import DrillHook
 
 
-class TestDrillHook(unittest.TestCase):
-    def setUp(self):
+class TestDrillHook:
+    def setup_method(self):
         self.cur = MagicMock(rowcount=0)
         self.conn = conn = MagicMock()
         self.conn.login = "drill_user"

tests/providers/apache/drill/operators/test_drill.py

Lines changed: 3 additions & 5 deletions
@@ -17,8 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
-
 import pytest
 
 from airflow.models.dag import DAG
@@ -32,13 +30,13 @@
 
 
 @pytest.mark.backend("drill")
-class TestDrillOperator(unittest.TestCase):
-    def setUp(self):
+class TestDrillOperator:
+    def setup_method(self):
         args = {"owner": "airflow", "start_date": DEFAULT_DATE}
         dag = DAG(TEST_DAG_ID, default_args=args)
         self.dag = dag
 
-    def tearDown(self):
+    def teardown_method(self):
         tables_to_drop = ["dfs.tmp.test_airflow"]
         from airflow.providers.apache.drill.hooks.drill import DrillHook
 
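As above, tearDown becomes teardown_method, the counterpart hook that runs after every test method whether it passed or failed, here used to drop the dfs.tmp.test_airflow table the tests create. A minimal self-contained sketch with illustrative names:

class TestWithCleanup:
    def setup_method(self):
        self.tables = ["dfs.tmp.test_airflow"]

    def teardown_method(self):
        # runs after each test, even if the assertion failed
        self.tables.clear()

    def test_table_registered(self):
        assert "dfs.tmp.test_airflow" in self.tables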
