Skip to content

Commit 0b31b2c

Browse files
authored
fix(hive): Regression in #21794 (#22794)
1 parent d091a68 commit 0b31b2c

File tree

3 files changed

+78
-4
lines changed

3 files changed

+78
-4
lines changed

superset/db_engine_specs/hive.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import tempfile
2323
import time
2424
from datetime import datetime
25-
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
25+
from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
2626
from urllib import parse
2727

2828
import numpy as np
@@ -576,3 +576,38 @@ def has_implicit_cancel(cls) -> bool:
576576
"""
577577

578578
return True
579+
580+
@classmethod
def get_view_names(
    cls,
    database: "Database",
    inspector: Inspector,
    schema: Optional[str],
) -> Set[str]:
    """
    Get all the view names within the specified schema.

    Per the SQLAlchemy definition if the schema is omitted the database’s default
    schema is used, however some dialects infer the request as schema agnostic.

    Note that PyHive's Hive SQLAlchemy dialect does not adhere to the
    specification, in that its `get_view_names` method returns both real tables
    and views. Furthermore the dialect wrongfully infers the request as schema
    agnostic when the schema is omitted. For that reason the views are fetched
    with an explicit `SHOW VIEWS` statement on a raw connection rather than via
    the inspector.

    :param database: The database to inspect
    :param inspector: The SQLAlchemy inspector (not used by this implementation)
    :param schema: The schema to inspect
    :returns: The view names
    """

    sql = "SHOW VIEWS"

    if schema:
        # Inside a backtick-quoted Hive identifier a literal backtick is written
        # as two backticks; escaping keeps an unusual schema name from
        # terminating the identifier early and altering the statement.
        escaped_schema = schema.replace("`", "``")
        sql += f" IN `{escaped_schema}`"

    with database.get_raw_connection(schema=schema) as conn:
        cursor = conn.cursor()
        cursor.execute(sql)
        # The view name is taken from the first column of every returned row.
        return {row[0] for row in cursor.fetchall()}

superset/db_engine_specs/presto.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -638,9 +638,10 @@ def get_view_names(
638638
Per the SQLAlchemy definition if the schema is omitted the database’s default
639639
schema is used, however some dialects infer the request as schema agnostic.
640640
641-
Note that PyHive's Hive and Presto SQLAlchemy dialects do not implement the
642-
`get_view_names` method. To ensure consistency with the `get_table_names` method
643-
the request is deemed schema agnostic when the schema is omitted.
641+
Note that PyHive's Presto SQLAlchemy dialect does not adhere to the
642+
specification as the `get_view_names` method is not defined. Furthermore the
643+
dialect wrongfully infers the request as schema agnostic when the schema is
644+
omitted.
644645
645646
:param database: The database to inspect
646647
:param inspector: The SQLAlchemy inspector

tests/integration_tests/db_engine_specs/hive_tests.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,41 @@ def is_correct_result(data: List, result: List) -> bool:
403403
["ds=01-01-19/hour=1", "ds=01-03-19/hour=1", "ds=01-02-19/hour=2"],
404404
["01-03-19", "1"],
405405
)
406+
407+
408+
def test_get_view_names_with_schema():
    """
    With a schema, the statement is scoped to it and only the first column of
    each fetched row is reported as a view name.
    """
    schema = "schema"
    database = mock.MagicMock()

    # MagicMock hands back the same child for repeated calls, so this is the
    # cursor object the engine spec receives inside its `with` block.
    cursor = database.get_raw_connection().__enter__().cursor()
    cursor.execute = mock.MagicMock()
    cursor.fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])

    view_names = HiveEngineSpec.get_view_names(database, mock.Mock(), schema)

    cursor.execute.assert_called_once_with(f"SHOW VIEWS IN `{schema}`")
    assert view_names == {"a", "d"}
420+
421+
422+
def test_get_view_names_without_schema():
    """
    Without a schema, a bare `SHOW VIEWS` is issued and only the first column of
    each fetched row is reported as a view name.
    """
    database = mock.MagicMock()

    # MagicMock hands back the same child for repeated calls, so this is the
    # cursor object the engine spec receives inside its `with` block.
    cursor = database.get_raw_connection().__enter__().cursor()
    cursor.execute = mock.MagicMock()
    cursor.fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])

    view_names = HiveEngineSpec.get_view_names(database, mock.Mock(), None)

    cursor.execute.assert_called_once_with("SHOW VIEWS")
    assert view_names == {"a", "d"}
432+
433+
434+
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.hive.HiveEngineSpec.get_view_names")
def test_get_table_names(
    mock_get_view_names,
    mock_get_table_names,
):
    """
    Names reported as views must not appear in the Hive table names, even when
    the base implementation returns them alongside the real tables.
    """
    views = {"view1", "view2"}
    mock_get_view_names.return_value = views
    # The base implementation is mocked to report tables and views together.
    mock_get_table_names.return_value = {"table1", "table2"} | views

    table_names = HiveEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)

    assert table_names == {"table1", "table2"}

0 commit comments

Comments
 (0)