Skip to content

Commit f0fe023

Browse files
Make a best effort attempt to initialise all Databricks globals (#562)
## Changes We only initialise dbutils locally when using `from databricks.sdk.runtime import *`. Users in the webui are guided to use this import in all library code. <img width="812" alt="image" src="https://github.com/databricks/databricks-sdk-py/assets/88345179/3f28bc3c-0ba9-41ac-b990-3c4f5bf138aa"> The local (for people outside DBR) solution so far was to initialise spark manually. But this can be tedious for deeply nested libraries (which is the reason this import was introduced in the first place). Now, we make a best effort attempt to initialise maximum number of globals locally, so that users can build and debug libraries using databricks connect. ## Tests * integration test - [ ] `make test` run locally - [x] `make fmt` applied - [ ] relevant integration tests applied --------- Signed-off-by: Kartik Gupta <[email protected]> Co-authored-by: Miles Yucht <[email protected]>
1 parent 5255760 commit f0fe023

File tree

9 files changed

+176
-66
lines changed

9 files changed

+176
-66
lines changed

.vscode/settings.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,9 @@
33
"tests"
44
],
55
"python.testing.unittestEnabled": false,
6-
"python.testing.pytestEnabled": true
6+
"python.testing.pytestEnabled": true,
7+
"python.envFile": "${workspaceFolder}/.databricks/.databricks.env",
8+
"databricks.python.envFile": "${workspaceFolder}/.env",
9+
"jupyter.interactiveWindow.cellMarker.codeRegex": "^# COMMAND ----------|^# Databricks notebook source|^(#\\s*%%|#\\s*\\<codecell\\>|#\\s*In\\[\\d*?\\]|#\\s*In\\[ \\])",
10+
"jupyter.interactiveWindow.cellMarker.default": "# COMMAND ----------"
711
}

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ test:
2424
pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests
2525

2626
integration:
27-
pytest -n auto -m 'integration and not benchmark' --cov=databricks --cov-report html tests
27+
pytest -n auto -m 'integration and not benchmark' --dist loadgroup --cov=databricks --cov-report html tests
2828

2929
benchmark:
3030
pytest -m 'benchmark' tests

databricks/sdk/_widgets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ def get(self, name: str):
1313
def _get(self, name: str) -> str:
1414
pass
1515

16-
def getArgument(self, name: str, default_value: typing.Optional[str] = None):
16+
def getArgument(self, name: str, defaultValue: typing.Optional[str] = None):
1717
try:
1818
return self.get(name)
1919
except Exception:
20-
return default_value
20+
return defaultValue
2121

2222
def remove(self, name: str):
2323
self._remove(name)

databricks/sdk/runtime/__init__.py

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import Dict, Union
4+
from typing import Dict, Optional, Union, cast
55

66
logger = logging.getLogger('databricks.sdk')
77
is_local_implementation = True
@@ -86,23 +86,97 @@ def inner() -> Dict[str, str]:
8686
_globals[var] = userNamespaceGlobals[var]
8787
is_local_implementation = False
8888
except ImportError:
89-
from typing import cast
90-
9189
# OSS implementation
9290
is_local_implementation = True
9391

94-
from databricks.sdk.dbutils import RemoteDbUtils
92+
for var in dbruntime_objects:
93+
globals()[var] = None
9594

96-
from . import dbutils_stub
95+
# The next few try-except blocks are for initialising globals in a best effort
96+
# manner. We separate them to try to get as many of them working as possible
97+
try:
98+
# We expect this to fail and only do this for providing types
99+
from pyspark.sql.context import SQLContext
100+
sqlContext: SQLContext = None # type: ignore
101+
table = sqlContext.table
102+
except Exception as e:
103+
logging.debug(f"Failed to initialize globals 'sqlContext' and 'table', continuing. Cause: {e}")
97104

98-
dbutils_type = Union[dbutils_stub.dbutils, RemoteDbUtils]
105+
try:
106+
from pyspark.sql.functions import udf # type: ignore
107+
except ImportError as e:
108+
logging.debug(f"Failed to initialise udf global: {e}")
99109

100110
try:
101-
from .stub import *
102-
except (ImportError, NameError):
103-
# this assumes that all environment variables are set
104-
dbutils = RemoteDbUtils()
111+
from databricks.connect import DatabricksSession # type: ignore
112+
spark = DatabricksSession.builder.getOrCreate()
113+
sql = spark.sql # type: ignore
114+
except Exception as e:
115+
# We are ignoring all failures here because user might want to initialize
116+
# spark session themselves and we don't want to interfere with that
117+
logging.debug(f"Failed to initialize globals 'spark' and 'sql', continuing. Cause: {e}")
105118

119+
try:
120+
# We expect this to fail locally since dbconnect does not support sparkcontext. This is just for typing
121+
sc = spark.sparkContext
122+
except Exception as e:
123+
logging.debug(f"Failed to initialize global 'sc', continuing. Cause: {e}")
124+
125+
def display(input=None, *args, **kwargs) -> None: # type: ignore
126+
"""
127+
Display plots or data.
128+
Display plot:
129+
- display() # no-op
130+
- display(matplotlib.figure.Figure)
131+
Display dataset:
132+
- display(spark.DataFrame)
133+
- display(list) # if list can be converted to DataFrame, e.g., list of named tuples
134+
- display(pandas.DataFrame)
135+
- display(koalas.DataFrame)
136+
- display(pyspark.pandas.DataFrame)
137+
Display any other value that has a _repr_html_() method
138+
For Spark 2.0 and 2.1:
139+
- display(DataFrame, streamName='optional', trigger=optional pyspark.sql.streaming.Trigger,
140+
checkpointLocation='optional')
141+
For Spark 2.2+:
142+
- display(DataFrame, streamName='optional', trigger=optional interval like '1 second',
143+
checkpointLocation='optional')
144+
"""
145+
# Import inside the function so that imports are only triggered on usage.
146+
from IPython import display as IPDisplay
147+
return IPDisplay.display(input, *args, **kwargs) # type: ignore
148+
149+
def displayHTML(html) -> None: # type: ignore
150+
"""
151+
Display HTML data.
152+
Parameters
153+
----------
154+
data : URL or HTML string
155+
If data is a URL, display the resource at that URL, the resource is loaded dynamically by the browser.
156+
Otherwise data should be the HTML to be displayed.
157+
See also:
158+
IPython.display.HTML
159+
IPython.display.display_html
160+
"""
161+
# Import inside the function so that imports are only triggered on usage.
162+
from IPython import display as IPDisplay
163+
return IPDisplay.display_html(html, raw=True) # type: ignore
164+
165+
# We want to propagate the error in initialising dbutils because this is a core
166+
# functionality of the sdk
167+
from databricks.sdk.dbutils import RemoteDbUtils
168+
169+
from . import dbutils_stub
170+
dbutils_type = Union[dbutils_stub.dbutils, RemoteDbUtils]
171+
172+
dbutils = RemoteDbUtils()
106173
dbutils = cast(dbutils_type, dbutils)
107174

108-
__all__ = ['dbutils'] if is_local_implementation else dbruntime_objects
175+
# We do this to prevent importing widgets implementation prematurely
176+
# The widget import should prompt users to use the implementation
177+
# which has ipywidget support.
178+
def getArgument(name: str, defaultValue: Optional[str] = None):
179+
return dbutils.widgets.getArgument(name, defaultValue)
180+
181+
182+
__all__ = dbruntime_objects

databricks/sdk/runtime/dbutils_stub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def get(name: str) -> str:
288288
...
289289

290290
@staticmethod
291-
def getArgument(name: str, defaultValue: typing.Optional[str] = None) -> str:
291+
def getArgument(name: str, defaultValue: typing.Optional[str] = None) -> str | None:
292292
"""Returns the current value of a widget with give name.
293293
:param name: Name of the argument to be accessed
294294
:param defaultValue: (Deprecated) default value

databricks/sdk/runtime/stub.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
install_requires=["requests>=2.28.1,<3", "google-auth~=2.0"],
1717
extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock",
1818
"yapf", "pycodestyle", "autoflake", "isort", "wheel",
19-
"ipython", "ipywidgets", "requests-mock", "pyfakefs"],
19+
"ipython", "ipywidgets", "requests-mock", "pyfakefs",
20+
"databricks-connect"],
2021
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]},
2122
author="Serge Smertin",
2223
author_email="[email protected]",

tests/integration/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def ucws(env_or_skip) -> WorkspaceClient:
8181
@pytest.fixture(scope='session')
8282
def env_or_skip():
8383

84-
def inner(var) -> str:
84+
def inner(var: str) -> str:
8585
if var not in os.environ:
8686
pytest.skip(f'Environment variable {var} is missing')
8787
return os.environ[var]
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import pytest
2+
3+
DBCONNECT_DBR_CLIENT = {"13.3": "13.3.3", "14.3": "14.3.1", }
4+
5+
6+
def reload_modules(name: str):
7+
"""
8+
Reloads the specified module. This is useful when testing Databricks Connect, since both
9+
the `databricks.connect` and `databricks.sdk.runtime` modules are stateful, and we need
10+
to reload these modules to reset the state cache between test runs.
11+
"""
12+
13+
import importlib
14+
import sys
15+
16+
v = sys.modules.get(name)
17+
if v is None:
18+
return
19+
try:
20+
print(f"Reloading {name}")
21+
importlib.reload(v)
22+
except Exception as e:
23+
print(f"Failed to reload {name}: {e}")
24+
25+
26+
@pytest.fixture(scope="function")
27+
def restorable_env():
28+
import os
29+
current_env = os.environ.copy()
30+
yield
31+
for k, v in os.environ.items():
32+
if k not in current_env:
33+
del os.environ[k]
34+
elif v != current_env[k]:
35+
os.environ[k] = current_env[k]
36+
37+
38+
@pytest.fixture(params=list(DBCONNECT_DBR_CLIENT.keys()))
39+
def setup_dbconnect_test(request, env_or_skip, restorable_env):
40+
dbr = request.param
41+
assert dbr in DBCONNECT_DBR_CLIENT, f"Unsupported Databricks Runtime version {dbr}. Please update DBCONNECT_DBR_CLIENT."
42+
43+
import os
44+
os.environ["DATABRICKS_CLUSTER_ID"] = env_or_skip(
45+
f"TEST_DBR_{dbr.replace('.', '_')}_DBCONNECT_CLUSTER_ID")
46+
47+
import subprocess
48+
import sys
49+
lib = f"databricks-connect=={DBCONNECT_DBR_CLIENT[dbr]}"
50+
subprocess.check_call([sys.executable, "-m", "pip", "install", lib])
51+
52+
yield
53+
54+
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "databricks-connect"])
55+
56+
57+
@pytest.mark.xdist_group(name="databricks-connect")
58+
def test_dbconnect_initialisation(w, setup_dbconnect_test):
59+
reload_modules("databricks.connect")
60+
from databricks.connect import DatabricksSession
61+
62+
spark = DatabricksSession.builder.getOrCreate()
63+
assert spark.sql("SELECT 1").collect()[0][0] == 1
64+
65+
66+
@pytest.mark.xdist_group(name="databricks-connect")
67+
def test_dbconnect_runtime_import(w, setup_dbconnect_test):
68+
reload_modules("databricks.sdk.runtime")
69+
from databricks.sdk.runtime import spark
70+
71+
assert spark.sql("SELECT 1").collect()[0][0] == 1
72+
73+
74+
@pytest.mark.xdist_group(name="databricks-connect")
75+
def test_dbconnect_runtime_import_no_error_if_doesnt_exist(w):
76+
reload_modules("databricks.sdk.runtime")
77+
from databricks.sdk.runtime import spark
78+
79+
assert spark is None

0 commit comments

Comments
 (0)