1414# KIND, either express or implied. See the License for the
1515# specific language governing permissions and limitations
1616# under the License.
17+ from __future__ import annotations
18+
1719import logging
20+ import re
1821from datetime import datetime
19- from typing import Any , Dict , List , Optional , Type , TYPE_CHECKING
22+ from typing import Any , cast , Dict , List , Optional , Type , TYPE_CHECKING
2023
24+ from flask import current_app
25+ from flask_babel import gettext as __
26+ from marshmallow import fields , Schema
27+ from marshmallow .validate import Range
2128from sqlalchemy import types
29+ from sqlalchemy .engine .url import URL
2230from urllib3 .exceptions import NewConnectionError
2331
24- from superset .db_engine_specs .base import BaseEngineSpec
32+ from superset .databases .utils import make_url_safe
33+ from superset .db_engine_specs .base import (
34+ BaseEngineSpec ,
35+ BasicParametersMixin ,
36+ BasicParametersType ,
37+ BasicPropertiesType ,
38+ )
2539from superset .db_engine_specs .exceptions import SupersetDBAPIDatabaseError
40+ from superset .errors import ErrorLevel , SupersetError , SupersetErrorType
2641from superset .extensions import cache_manager
42+ from superset .utils .core import GenericDataType
43+ from superset .utils .hashing import md5_sha_from_str
44+ from superset .utils .network import is_hostname_valid , is_port_open
2745
2846if TYPE_CHECKING :
29- # prevent circular imports
3047 from superset .models .core import Database
3148
3249logger = logging .getLogger (__name__ )
3350
3451
35- class ClickHouseEngineSpec (BaseEngineSpec ): # pylint: disable=abstract-method
36- """Dialect for ClickHouse analytical DB."""
37-
38- engine = "clickhouse"
39- engine_name = "ClickHouse"
52+ class ClickHouseBaseEngineSpec (BaseEngineSpec ):
53+ """Shared engine spec for ClickHouse."""
4054
4155 time_secondary_columns = True
4256 time_groupby_inline = True
@@ -56,8 +70,78 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
5670 "P1Y" : "toStartOfYear(toDateTime({col}))" ,
5771 }
5872
59- _show_functions_column = "name"
73+ column_type_mappings = (
74+ (
75+ re .compile (r".*Enum.*" , re .IGNORECASE ),
76+ types .String (),
77+ GenericDataType .STRING ,
78+ ),
79+ (
80+ re .compile (r".*Array.*" , re .IGNORECASE ),
81+ types .String (),
82+ GenericDataType .STRING ,
83+ ),
84+ (
85+ re .compile (r".*UUID.*" , re .IGNORECASE ),
86+ types .String (),
87+ GenericDataType .STRING ,
88+ ),
89+ (
90+ re .compile (r".*Bool.*" , re .IGNORECASE ),
91+ types .Boolean (),
92+ GenericDataType .BOOLEAN ,
93+ ),
94+ (
95+ re .compile (r".*String.*" , re .IGNORECASE ),
96+ types .String (),
97+ GenericDataType .STRING ,
98+ ),
99+ (
100+ re .compile (r".*Int\d+.*" , re .IGNORECASE ),
101+ types .INTEGER (),
102+ GenericDataType .NUMERIC ,
103+ ),
104+ (
105+ re .compile (r".*Decimal.*" , re .IGNORECASE ),
106+ types .DECIMAL (),
107+ GenericDataType .NUMERIC ,
108+ ),
109+ (
110+ re .compile (r".*DateTime.*" , re .IGNORECASE ),
111+ types .DateTime (),
112+ GenericDataType .TEMPORAL ,
113+ ),
114+ (
115+ re .compile (r".*Date.*" , re .IGNORECASE ),
116+ types .Date (),
117+ GenericDataType .TEMPORAL ,
118+ ),
119+ )
120+
121+ @classmethod
122+ def epoch_to_dttm (cls ) -> str :
123+ return "{col}"
124+
125+ @classmethod
126+ def convert_dttm (
127+ cls , target_type : str , dttm : datetime , db_extra : Optional [Dict [str , Any ]] = None
128+ ) -> Optional [str ]:
129+ sqla_type = cls .get_sqla_column_type (target_type )
130+
131+ if isinstance (sqla_type , types .Date ):
132+ return f"toDate('{ dttm .date ().isoformat ()} ')"
133+ if isinstance (sqla_type , types .DateTime ):
134+ return f"""toDateTime('{ dttm .isoformat (sep = " " , timespec = "seconds" )} ')"""
135+ return None
136+
60137
138+ class ClickHouseEngineSpec (ClickHouseBaseEngineSpec ):
139+ """Engine spec for clickhouse_sqlalchemy connector"""
140+
141+ engine = "clickhouse"
142+ engine_name = "ClickHouse"
143+
144+ _show_functions_column = "name"
61145 supports_file_upload = False
62146
63147 @classmethod
@@ -73,21 +157,9 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
73157 return exception
74158 return new_exception (str (exception ))
75159
76- @classmethod
77- def convert_dttm (
78- cls , target_type : str , dttm : datetime , db_extra : Optional [Dict [str , Any ]] = None
79- ) -> Optional [str ]:
80- sqla_type = cls .get_sqla_column_type (target_type )
81-
82- if isinstance (sqla_type , types .Date ):
83- return f"toDate('{ dttm .date ().isoformat ()} ')"
84- if isinstance (sqla_type , types .DateTime ):
85- return f"""toDateTime('{ dttm .isoformat (sep = " " , timespec = "seconds" )} ')"""
86- return None
87-
88160 @classmethod
89161 @cache_manager .cache .memoize ()
90- def get_function_names (cls , database : " Database" ) -> List [str ]:
162+ def get_function_names (cls , database : Database ) -> List [str ]:
91163 """
92164 Get a list of function names that are able to be called on the database.
93165 Used for SQL Lab autocomplete.
@@ -123,3 +195,201 @@ def get_function_names(cls, database: "Database") -> List[str]:
123195
124196 # otherwise, return no function names to prevent errors
125197 return []
198+
199+
class ClickHouseParametersSchema(Schema):
    """Marshmallow schema for the individual ClickHouse connection parameters."""

    # Credentials are optional.
    username = fields.String(
        allow_none=True,
        description=__("Username"),
    )
    password = fields.String(
        allow_none=True,
        description=__("Password"),
    )

    # The host is the only required parameter.
    host = fields.String(
        required=True,
        description=__("Hostname or IP address"),
    )
    port = fields.Integer(
        allow_none=True,
        description=__("Database port"),
        validate=Range(min=0, max=65535),
    )
    database = fields.String(
        allow_none=True,
        description=__("Database name"),
    )

    encryption = fields.Boolean(
        default=True,
        description=__("Use an encrypted connection to the database"),
    )

    # Free-form extra parameters: string keys, arbitrary values.
    query = fields.Dict(
        keys=fields.Str(),
        values=fields.Raw(),
        description=__("Additional parameters"),
    )
216+
217+
# Optional one-time configuration of the clickhouse-connect driver. Only the
# imports are guarded by ``except ImportError`` so that an unrelated failure in
# the configuration calls below is not mistaken for "driver not installed".
try:
    from clickhouse_connect.common import set_setting
    from clickhouse_connect.datatypes.format import set_default_formats
except ImportError:  # ClickHouse Connect not installed, do nothing
    pass
else:
    # override default formats for compatibility
    # NOTE(review): the arguments look like (type pattern, format) pairs, but
    # there is an odd number of them and "signed" has no preceding pattern -
    # a pattern may be missing before it; confirm against the clickhouse-connect
    # documentation for set_default_formats.
    set_default_formats(
        "FixedString",
        "string",
        "IPv*",
        "string",
        "signed",
        "UUID",
        "string",
        "*Int256",
        "string",
        "*Int128",
        "string",
    )

    # ``current_app`` is only usable inside a Flask application context. This
    # module can be imported outside of one, in which case touching the proxy
    # raises RuntimeError; fall back to the same default used when the config
    # key is absent instead of failing the whole import.
    try:
        _superset_version = current_app.config.get("VERSION_STRING", "dev")
    except RuntimeError:
        _superset_version = "dev"
    set_setting("product_name", f"superset/{_superset_version}")
242+
243+
class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
    """Engine spec for clickhouse-connect connector"""

    engine = "clickhousedb"
    engine_name = "ClickHouse Connect"

    default_driver = "connect"
    # Class-level cache filled by ``get_function_names`` after the first
    # successful lookup.
    _function_names: List[str] = []

    sqlalchemy_uri_placeholder = (
        "clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]"
    )
    parameters_schema = ClickHouseParametersSchema()
    # Query-string parameters added to the URI when encryption is requested.
    encryption_parameters = {"secure": "true"}

    @classmethod
    def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
        """Return the driver-to-Superset exception mapping (empty here)."""
        return {}

    @classmethod
    def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
        """
        Translate a driver exception using ``get_dbapi_exception_mapping``.

        :param exception: The exception raised by the driver
        :return: A generic "Connection failed" ``SupersetDBAPIDatabaseError``
            when the exception maps to that class, a new instance of the mapped
            class otherwise, or the original exception when no mapping exists
        """
        new_exception = cls.get_dbapi_exception_mapping().get(type(exception))
        if new_exception == SupersetDBAPIDatabaseError:
            return SupersetDBAPIDatabaseError("Connection failed")
        if not new_exception:
            return exception
        return new_exception(str(exception))

    @classmethod
    def get_function_names(cls, database: Database) -> List[str]:
        """
        Return the function and table-function names known to the server.

        The first non-empty result is cached on the class and reused for
        subsequent calls.

        :param database: Database used to run the lookup query
        :return: List of names, or an empty list when the query fails with a
            ``ClickHouseError``
        """
        # pylint: disable=import-outside-toplevel,import-error
        from clickhouse_connect.driver.exceptions import ClickHouseError

        if cls._function_names:
            return cls._function_names
        try:
            names = database.get_df(
                "SELECT name FROM system.functions UNION ALL "
                + "SELECT name FROM system.table_functions LIMIT 10000"
            )["name"].tolist()
            cls._function_names = names
            return names
        except ClickHouseError:
            logger.exception("Error retrieving system.functions")
            return []

    @classmethod
    def get_datatype(cls, type_code: str) -> str:
        """Return ``type_code`` unchanged."""
        # keep it lowercase, as ClickHouse types aren't typical SHOUTCASE ANSI SQL
        return type_code

    @classmethod
    def build_sqlalchemy_uri(
        cls,
        parameters: BasicParametersType,
        encrypted_extra: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Build a SQLAlchemy URI string from individual connection parameters.

        :param parameters: Connection parameters; the input mapping is not
            modified
        :param encrypted_extra: Unused by this implementation
        :return: The URI as a string, using ``engine+default_driver`` as the
            driver name
        """
        url_params = parameters.copy()
        if url_params.get("encryption"):
            # Merge into a copy so the caller's query dict stays untouched.
            query = parameters.get("query", {}).copy()
            query.update(cls.encryption_parameters)
            url_params["query"] = query
        if not url_params.get("database"):
            # Placeholder for "no database chosen"; mapped back to an empty
            # string by get_parameters_from_uri.
            url_params["database"] = "__default__"
        # "encryption" is not a URL component; it is expressed via the query.
        url_params.pop("encryption", None)
        return str(URL(f"{cls.engine}+{cls.default_driver}", **url_params))

    @classmethod
    def get_parameters_from_uri(
        cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None
    ) -> BasicParametersType:
        """
        Split a SQLAlchemy URI back into individual connection parameters.

        :param uri: The SQLAlchemy URI to parse
        :param encrypted_extra: Unused by this implementation
        :return: The parameters, with the ``secure`` query argument removed
            from ``query`` and reported through ``encryption`` instead
        """
        url = make_url_safe(uri)
        # Work on a copy: popping from ``url.query`` directly would mutate the
        # URL object and fails outright when that mapping is immutable.
        query = dict(url.query)
        # Encryption is on only when ``secure`` is present and equals "true".
        encryption = query.pop("secure", None) == "true"
        return BasicParametersType(
            username=url.username,
            password=url.password,
            host=url.host,
            port=url.port,
            database="" if url.database == "__default__" else cast(str, url.database),
            query=query,
            encryption=encryption,
        )

    @classmethod
    def validate_parameters(
        cls, properties: BasicPropertiesType
    ) -> List[SupersetError]:
        """
        Validate host and port before attempting a connection.

        Checks run in order and stop at the first failure: host present, host
        resolvable, port numeric and in range, port open.

        :param properties: Properties whose ``parameters`` entry holds the
            individual connection parameters
        :return: A single-element list describing the first problem found, or
            an empty list when every check passes
        """
        # pylint: disable=import-outside-toplevel,import-error
        from clickhouse_connect.driver import default_port

        parameters = properties.get("parameters", {})
        host = parameters.get("host", None)
        if not host:
            return [
                SupersetError(
                    "Hostname is required",
                    SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
                    ErrorLevel.WARNING,
                    {"missing": ["host"]},
                )
            ]
        if not is_hostname_valid(host):
            return [
                SupersetError(
                    "The hostname provided can't be resolved.",
                    SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
                    ErrorLevel.ERROR,
                    {"invalid": ["host"]},
                )
            ]
        port = parameters.get("port")
        if port is None:
            # Fall back to the driver's default port for the chosen mode.
            port = default_port("http", parameters.get("encryption", False))
        try:
            port = int(port)
        except (ValueError, TypeError):
            # Force the range check below to fail for non-numeric input.
            port = -1
        # 65535 is the highest valid port and must be accepted; port 0 is
        # still rejected as before.
        if port <= 0 or port > 65535:
            return [
                SupersetError(
                    "Port must be a valid integer between 0 and 65535 (inclusive).",
                    SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
                    ErrorLevel.ERROR,
                    {"invalid": ["port"]},
                )
            ]
        if not is_port_open(host, port):
            return [
                SupersetError(
                    "The port is closed.",
                    SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
                    ErrorLevel.ERROR,
                    {"invalid": ["port"]},
                )
            ]
        return []

    @staticmethod
    def _mutate_label(label: str) -> str:
        """
        Suffix with the first six characters from the md5 of the label to avoid
        collisions with original column names

        :param label: Expected expression label
        :return: Conditionally mutated label
        """
        return f"{label}_{md5_sha_from_str(label)[:6]}"
0 commit comments