Skip to content

Commit f03a5ab

Browse files
authored
Added MockBackend.rows("col1", "col2")[(...), (...)] helper (#49)
This PR makes testing with `MockBackend` easier.
1 parent f394f4d commit f03a5ab

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

src/databricks/labs/lsql/backends.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,32 @@ def rows_written_for(self, full_name: str, mode: str) -> list[DataclassInstance]
301301
rows += stub_rows
302302
return rows
303303

304+
@staticmethod
305+
def rows(*column_names: str):
306+
"""This method is used to create rows for the mock backend."""
307+
number_of_columns = len(column_names)
308+
row_factory = Row.factory(list(column_names))
309+
310+
class MagicFactory:
311+
"""This class is used to create rows for the mock backend."""
312+
313+
def __getitem__(self, tuples: list[tuple | list] | tuple[list | tuple]) -> list[Row]:
314+
if not isinstance(tuples, (list, tuple)):
315+
raise TypeError(f"Expected list or tuple, got {type(tuples)}")
316+
# fix sloppy input
317+
if tuples and not isinstance(tuples[0], (list, tuple)):
318+
tuples = [tuples]
319+
out = []
320+
for record in tuples:
321+
if not isinstance(record, (list, tuple)):
322+
raise TypeError(f"Expected list or tuple, got {type(record)}")
323+
if number_of_columns != len(record):
324+
raise TypeError(f"Expected {number_of_columns} columns, got {len(record)}: {record}")
325+
out.append(row_factory(*record))
326+
return out
327+
328+
return MagicFactory()
329+
304330
@staticmethod
305331
def _row_factory(klass: Dataclass) -> type:
306332
return Row.factory([f.name for f in dataclasses.fields(klass)])

src/databricks/labs/lsql/core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ def as_dict(self) -> dict[str, Any]:
6666
"""Convert the row to a dictionary with the same conventions as Databricks SDK."""
6767
return dict(zip(self.__columns__, self, strict=True))
6868

69+
def __eq__(self, other):
70+
"""Check if the rows are equal."""
71+
if not isinstance(other, Row):
72+
return False
73+
# compare rows as dictionaries, because the order
74+
# of fields in constructor is not guaranteed
75+
return self.as_dict() == other.as_dict()
76+
6977
def __contains__(self, item):
7078
"""Check if the column is in the row."""
7179
return item in self.__columns__

tests/unit/test_backends.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,14 @@ def test_mock_backend_save_table():
346346
Row(first="aaa", second=True),
347347
Row(first="bbb", second=False),
348348
]
349+
350+
351+
def test_mock_backend_rows_dsl():
352+
rows = MockBackend.rows("foo", "bar")[
353+
[1, 2],
354+
(3, 4),
355+
]
356+
assert rows == [
357+
Row(foo=1, bar=2),
358+
Row(foo=3, bar=4),
359+
]

0 commit comments

Comments
 (0)