Skip to content

Commit 58fcd29

Browse files
authored
fix: Apply normalization to all dttm columns (#25147)
1 parent 17792a5 commit 58fcd29

File tree

5 files changed

+160
-10
lines changed

5 files changed

+160
-10
lines changed

superset/common/query_context_factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def _apply_granularity(
185185
filter
186186
for filter in query_object.filter
187187
if filter["col"] != filter_to_remove
188+
or filter["op"] != "TEMPORAL_RANGE"
188189
]
189190

190191
def _apply_filters(self, query_object: QueryObject) -> None:

superset/common/query_context_processor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,11 @@ def _get_timestamp_format(
282282
datasource = self._qc_datasource
283283
labels = tuple(
284284
label
285-
for label in [
285+
for label in {
286286
*get_base_axis_labels(query_object.columns),
287+
*[col for col in query_object.columns or [] if isinstance(col, str)],
287288
query_object.granularity,
288-
]
289+
}
289290
if datasource
290291
# Query datasource didn't support `get_column`
291292
and hasattr(datasource, "get_column")

superset/common/query_object_factory.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,24 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
from datetime import datetime
1920
from typing import Any, TYPE_CHECKING
2021

2122
from superset.common.chart_data import ChartDataResultType
2223
from superset.common.query_object import QueryObject
2324
from superset.common.utils.time_range_utils import get_since_until_from_time_range
24-
from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType
25+
from superset.utils.core import (
26+
apply_max_row_limit,
27+
DatasourceDict,
28+
DatasourceType,
29+
FilterOperator,
30+
QueryObjectFilterClause,
31+
)
2532

2633
if TYPE_CHECKING:
2734
from sqlalchemy.orm import sessionmaker
2835

29-
from superset.connectors.base.models import BaseDatasource
36+
from superset.connectors.base.models import BaseColumn, BaseDatasource
3037
from superset.daos.datasource import DatasourceDAO
3138

3239

@@ -66,6 +73,10 @@ def create( # pylint: disable=too-many-arguments
6673
)
6774
kwargs["from_dttm"] = from_dttm
6875
kwargs["to_dttm"] = to_dttm
76+
if datasource_model_instance and kwargs.get("filters", []):
77+
kwargs["filters"] = self._process_filters(
78+
datasource_model_instance, kwargs["filters"]
79+
)
6980
return QueryObject(
7081
datasource=datasource_model_instance,
7182
extras=extras,
@@ -102,3 +113,54 @@ def _process_row_limit(
102113
# light version of the view.utils.core
103114
# import view.utils require application context
104115
# Todo: move it and the view.utils.core to utils package
116+
117+
def _process_filters(
118+
self, datasource: BaseDatasource, query_filters: list[QueryObjectFilterClause]
119+
) -> list[QueryObjectFilterClause]:
120+
def get_dttm_filter_value(
121+
value: Any, col: BaseColumn, date_format: str
122+
) -> int | str:
123+
if not isinstance(value, int):
124+
return value
125+
if date_format in {"epoch_ms", "epoch_s"}:
126+
if date_format == "epoch_s":
127+
value = str(value)
128+
else:
129+
value = str(value * 1000)
130+
else:
131+
dttm = datetime.utcfromtimestamp(value / 1000)
132+
value = dttm.strftime(date_format)
133+
134+
if col.type in col.num_types:
135+
value = int(value)
136+
return value
137+
138+
for query_filter in query_filters:
139+
if query_filter.get("op") == FilterOperator.TEMPORAL_RANGE:
140+
continue
141+
filter_col = query_filter.get("col")
142+
if not isinstance(filter_col, str):
143+
continue
144+
column = datasource.get_column(filter_col)
145+
if not column:
146+
continue
147+
filter_value = query_filter.get("val")
148+
149+
date_format = column.python_date_format
150+
if not date_format and datasource.db_extra:
151+
date_format = datasource.db_extra.get(
152+
"python_date_format_by_column_name", {}
153+
).get(column.column_name)
154+
155+
if column.is_dttm and date_format:
156+
if isinstance(filter_value, list):
157+
query_filter["val"] = [
158+
get_dttm_filter_value(value, column, date_format)
159+
for value in filter_value
160+
]
161+
else:
162+
query_filter["val"] = get_dttm_filter_value(
163+
filter_value, column, date_format
164+
)
165+
166+
return query_filters

tests/integration_tests/query_context_tests.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -836,11 +836,9 @@ def test_special_chars_in_column_name(app_context, physical_dataset):
836836

837837
query_object = qc.queries[0]
838838
df = qc.get_df_payload(query_object)["df"]
839-
if query_object.datasource.database.backend == "sqlite":
840-
# sqlite returns string as timestamp column
841-
assert df["time column with spaces"][0] == "2002-01-03 00:00:00"
842-
assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00"
843-
else:
839+
840+
# sqlite doesn't have timestamp columns
841+
if query_object.datasource.database.backend != "sqlite":
844842
assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03"
845843
assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01"
846844

tests/unit_tests/common/test_query_object_factory.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,45 @@ def session_factory() -> Mock:
4343
return Mock()
4444

4545

46+
class SimpleDatasetColumn:
47+
def __init__(self, col_params: dict[str, Any]):
48+
self.__dict__.update(col_params)
49+
50+
51+
TEMPORAL_COLUMN_NAMES = ["temporal_column", "temporal_column_with_python_date_format"]
52+
TEMPORAL_COLUMNS = {
53+
TEMPORAL_COLUMN_NAMES[0]: SimpleDatasetColumn(
54+
{
55+
"column_name": TEMPORAL_COLUMN_NAMES[0],
56+
"is_dttm": True,
57+
"python_date_format": None,
58+
"type": "string",
59+
"num_types": ["BIGINT"],
60+
}
61+
),
62+
TEMPORAL_COLUMN_NAMES[1]: SimpleDatasetColumn(
63+
{
64+
"column_name": TEMPORAL_COLUMN_NAMES[1],
65+
"type": "BIGINT",
66+
"is_dttm": True,
67+
"python_date_format": "%Y",
68+
"num_types": ["BIGINT"],
69+
}
70+
),
71+
}
72+
73+
4674
@fixture
4775
def connector_registry() -> Mock:
48-
return Mock(spec=["get_datasource"])
76+
datasource_dao_mock = Mock(spec=["get_datasource"])
77+
datasource_dao_mock.get_datasource.return_value = Mock()
78+
datasource_dao_mock.get_datasource().get_column = Mock(
79+
side_effect=lambda col_name: TEMPORAL_COLUMNS[col_name]
80+
if col_name in TEMPORAL_COLUMN_NAMES
81+
else Mock()
82+
)
83+
datasource_dao_mock.get_datasource().db_extra = None
84+
return datasource_dao_mock
4985

5086

5187
def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
@@ -112,3 +148,55 @@ def test_query_context_null_post_processing_op(
112148
raw_query_context["result_type"], **raw_query_object
113149
)
114150
assert query_object.post_processing == []
151+
152+
def test_query_context_no_python_date_format_filters(
153+
self,
154+
query_object_factory: QueryObjectFactory,
155+
raw_query_context: dict[str, Any],
156+
):
157+
raw_query_object = raw_query_context["queries"][0]
158+
raw_query_object["filters"].append(
159+
{"col": TEMPORAL_COLUMN_NAMES[0], "op": "==", "val": 315532800000}
160+
)
161+
query_object = query_object_factory.create(
162+
raw_query_context["result_type"],
163+
raw_query_context["datasource"],
164+
**raw_query_object
165+
)
166+
assert query_object.filter[3]["val"] == 315532800000
167+
168+
def test_query_context_python_date_format_filters(
169+
self,
170+
query_object_factory: QueryObjectFactory,
171+
raw_query_context: dict[str, Any],
172+
):
173+
raw_query_object = raw_query_context["queries"][0]
174+
raw_query_object["filters"].append(
175+
{"col": TEMPORAL_COLUMN_NAMES[1], "op": "==", "val": 315532800000}
176+
)
177+
query_object = query_object_factory.create(
178+
raw_query_context["result_type"],
179+
raw_query_context["datasource"],
180+
**raw_query_object
181+
)
182+
assert query_object.filter[3]["val"] == 1980
183+
184+
def test_query_context_python_date_format_filters_list_of_values(
185+
self,
186+
query_object_factory: QueryObjectFactory,
187+
raw_query_context: dict[str, Any],
188+
):
189+
raw_query_object = raw_query_context["queries"][0]
190+
raw_query_object["filters"].append(
191+
{
192+
"col": TEMPORAL_COLUMN_NAMES[1],
193+
"op": "==",
194+
"val": [315532800000, 631152000000],
195+
}
196+
)
197+
query_object = query_object_factory.create(
198+
raw_query_context["result_type"],
199+
raw_query_context["datasource"],
200+
**raw_query_object
201+
)
202+
assert query_object.filter[3]["val"] == [1980, 1990]

0 commit comments

Comments
 (0)