Skip to content

Commit ca137d1

Browse files
Add support for save_table(..., mode="overwrite") to StatementExecutionBackend (#74)
1 parent 8921e0f commit ca137d1

File tree

3 files changed

+69
-10
lines changed

3 files changed

+69
-10
lines changed

src/databricks/labs/lsql/backends.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -146,15 +146,14 @@ def fetch(self, sql: str) -> Iterator[Row]:
146146
return self._sql.fetch_all(sql)
147147

148148
def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"):
149-
if mode == "overwrite":
150-
msg = "Overwrite mode is not yet supported"
151-
raise NotImplementedError(msg)
152149
rows = self._filter_none_rows(rows, klass)
153150
self.create_table(full_name, klass)
154151
if len(rows) == 0:
155152
return
156153
fields = dataclasses.fields(klass)
157154
field_names = [f.name for f in fields]
155+
if mode == "overwrite":
156+
self.execute(f"TRUNCATE TABLE {full_name}")
158157
for i in range(0, len(rows), self._max_records_per_batch):
159158
batch = rows[i : i + self._max_records_per_batch]
160159
vals = "), (".join(self._row_to_sql(r, fields) for r in batch)
@@ -283,10 +282,9 @@ def fetch(self, sql) -> Iterator[Row]:
283282
return iter(rows)
284283

285284
def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"):
286-
if mode == "overwrite":
287-
msg = "Overwrite mode is not yet supported"
288-
raise NotImplementedError(msg)
289285
rows = self._filter_none_rows(rows, klass)
286+
if mode == "overwrite":
287+
self._save_table = []
290288
if klass.__class__ == type:
291289
row_factory = self._row_factory(klass)
292290
rows = [row_factory(*dataclasses.astuple(r)) for r in rows]

tests/integration/test_deployment.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,3 +22,14 @@ def test_deploys_database(ws, env_or_skip, make_random):
2222
rows = list(sql_backend.fetch(f"SELECT * FROM {schema}.some"))
2323

2424
assert rows == [Row(name="abc", id=1)]
25+
26+
27+
def test_overwrite(ws, env_or_skip, make_random):
28+
schema = "default"
29+
sql_backend = StatementExecutionBackend(ws, env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"))
30+
31+
sql_backend.save_table(f"{schema}.foo", [views.Foo("abc", True)], views.Foo, "append")
32+
sql_backend.save_table(f"{schema}.foo", [views.Foo("xyz", True)], views.Foo, "overwrite")
33+
rows = list(sql_backend.fetch(f"SELECT * FROM {schema}.some"))
34+
35+
assert rows == [Row(name="xyz", id=1)]

tests/unit/test_backends.py

Lines changed: 54 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -93,10 +93,47 @@ def test_statement_execution_backend_fetch_happy():
9393
assert [Row(id=1), Row(id=2), Row(id=3)] == result
9494

9595

96-
def test_statement_execution_backend_save_table_overwrite(mocker):
97-
seb = StatementExecutionBackend(mocker.Mock(), "abc")
98-
with pytest.raises(NotImplementedError):
99-
seb.save_table("a.b.c", [1, 2, 3], Bar, mode="overwrite")
96+
def test_statement_execution_backend_save_table_overwrite_empty_table():
97+
ws = create_autospec(WorkspaceClient)
98+
ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
99+
status=StatementStatus(state=StatementState.SUCCEEDED)
100+
)
101+
seb = StatementExecutionBackend(ws, "abc")
102+
seb.save_table("a.b.c", [Baz("1")], Baz, mode="overwrite")
103+
ws.statement_execution.execute_statement.assert_has_calls(
104+
[
105+
mock.call(
106+
warehouse_id="abc",
107+
statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second STRING) USING DELTA",
108+
catalog=None,
109+
schema=None,
110+
disposition=None,
111+
format=Format.JSON_ARRAY,
112+
byte_limit=None,
113+
wait_timeout=None,
114+
),
115+
mock.call(
116+
warehouse_id="abc",
117+
statement="TRUNCATE TABLE a.b.c",
118+
catalog=None,
119+
schema=None,
120+
disposition=None,
121+
format=Format.JSON_ARRAY,
122+
byte_limit=None,
123+
wait_timeout=None,
124+
),
125+
mock.call(
126+
warehouse_id="abc",
127+
statement="INSERT INTO a.b.c (first, second) VALUES ('1', NULL)",
128+
catalog=None,
129+
schema=None,
130+
disposition=None,
131+
format=Format.JSON_ARRAY,
132+
byte_limit=None,
133+
wait_timeout=None,
134+
),
135+
]
136+
)
100137

101138

102139
def test_statement_execution_backend_save_table_empty_records():
@@ -357,3 +394,16 @@ def test_mock_backend_rows_dsl():
357394
Row(foo=1, bar=2),
358395
Row(foo=3, bar=4),
359396
]
397+
398+
399+
def test_mock_backend_overwrite():
400+
mock_backend = MockBackend()
401+
mock_backend.save_table("a.b.c", [Foo("a1", True), Foo("c2", False)], Foo, "append")
402+
mock_backend.save_table("a.b.c", [Foo("aa", True), Foo("bb", False)], Foo, "overwrite")
403+
mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo, "overwrite")
404+
405+
assert mock_backend.rows_written_for("a.b.c", "append") == []
406+
assert mock_backend.rows_written_for("a.b.c", "overwrite") == [
407+
Row(first="aaa", second=True),
408+
Row(first="bbb", second=False),
409+
]

0 commit comments

Comments
 (0)