-
Notifications
You must be signed in to change notification settings - Fork 101
Expand file tree
/
Copy pathgeneric.py
More file actions
471 lines (409 loc) · 18.7 KB
/
generic.py
File metadata and controls
471 lines (409 loc) · 18.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
import datetime
import json
import logging
from collections.abc import Callable, Iterable, Iterator
from dataclasses import dataclass
from datetime import timedelta
from functools import partial
from databricks.labs.blueprint.limiter import rate_limited
from databricks.labs.blueprint.parallel import ManyError, Threads
from databricks.labs.lsql.backends import SqlBackend
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import (
DeadlineExceeded,
InternalError,
InvalidParameterValue,
NotFound,
PermissionDenied,
ResourceConflict,
TemporarilyUnavailable,
)
from databricks.sdk.retries import retried
from databricks.sdk.service import iam, ml
from databricks.sdk.service.iam import PermissionLevel
from databricks.labs.ucx.framework.crawlers import CrawlerBase
from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.workspace_access.base import AclSupport, Permissions, StaticListing
from databricks.labs.ucx.workspace_access.groups import MigrationState
logger = logging.getLogger(__name__)
@dataclass
class GenericPermissionsInfo:
    """A single (object_id, request_type) pair identifying an object whose ACL should be crawled."""

    # id understood by the permissions API for this object (e.g. a cluster/job id, or "/root")
    object_id: str
    # object-type segment used when calling the permissions API (e.g. "clusters", "feature-tables")
    request_type: str
@dataclass
class WorkspaceObjectInfo:
    """Inventory row describing one workspace object (notebook, directory, repo or file)."""

    # workspace path of the object
    path: str
    # raw workspace object type as returned by the workspace listing API (e.g. "NOTEBOOK")
    object_type: str | None = None
    # workspace object id, stored as a string
    object_id: str | None = None
    # source language for notebooks; None for other object types
    language: str | None = None
class Listing:
    """Enumerates objects of one type through a caller-supplied listing function.

    Iterating a ``Listing`` invokes the function and yields a
    ``GenericPermissionsInfo`` per returned item, reading the configured id
    attribute off each item. ``NotFound``/``InternalError`` during listing are
    logged and swallowed so one failing listing does not abort a crawl.
    """

    def __init__(self, func: Callable[..., Iterable], id_attribute: str, object_type: str):
        self._func = func
        self._id_attribute = id_attribute
        self._object_type = object_type

    def object_types(self) -> set[str]:
        """Return the single object type produced by this listing."""
        return {self._object_type}

    def __iter__(self):
        start = datetime.datetime.now()
        try:
            for raw in self._func():
                object_id = getattr(raw, self._id_attribute)
                yield GenericPermissionsInfo(object_id, self._object_type)
        except (NotFound, InternalError) as err:
            # best-effort: log and keep going so other listings still run
            logger.warning(f"Listing {self._object_type} failed", exc_info=err)
        elapsed = datetime.datetime.now() - start
        logger.info(f"Listed {self._object_type} in {elapsed}")

    def __repr__(self):
        return f"Listing({self._object_type} via {self._func.__qualname__})"
class GenericPermissionsSupport(AclSupport):
    """Crawls, migrates and verifies object-level ACLs for generic workspace objects.

    Each configured ``Listing`` contributes (object_type, object_id) pairs; for
    every pair the crawler reads the object's ACL through the workspace
    permissions API and persists it as a ``Permissions`` row. During migration
    the stored ACL is re-applied with workspace-local group names replaced by
    their target (account-level) principals, and the applied permissions are
    re-read to confirm the API actually persisted them.
    """

    def __init__(
        self,
        ws: WorkspaceClient,
        listings: list[Listing],
        verify_timeout: timedelta | None = timedelta(minutes=1),
        # this parameter is for testing scenarios only - [{object_type}:{object_id}]
        # it will use StaticListing class to return only object ids that has the same object type
        include_object_permissions: list[str] | None = None,
    ):
        self._ws = ws
        self._listings = listings
        self._verify_timeout = verify_timeout
        self._include_object_permissions = include_object_permissions

    def get_crawler_tasks(self):
        """Yield one crawler task (callable) per discovered (object_type, object_id) pair.

        When ``include_object_permissions`` is set, only those explicitly listed
        objects are crawled; otherwise all configured listings are iterated.
        """
        if self._include_object_permissions:
            for item in StaticListing(self._include_object_permissions, self.object_types()):
                yield partial(self._crawler_task, item.object_type, item.object_id)
            return
        for listing in self._listings:
            for info in listing:
                yield partial(self._crawler_task, info.request_type, info.object_id)

    def object_types(self) -> set[str]:
        """Return the union of object types covered by all configured listings."""
        all_object_types = set()
        for listing in self._listings:
            for object_type in listing.object_types():
                all_object_types.add(object_type)
        return all_object_types

    def get_apply_task(self, item: Permissions, migration_state: MigrationState):
        """Return a task that re-applies the stored ACL with migrated group names.

        Returns None when no group in the stored ACL is in migration scope.
        """
        if not self._is_item_relevant(item, migration_state):
            return None
        object_permissions = iam.ObjectPermissions.from_dict(json.loads(item.raw))
        new_acl = self._prepare_new_acl(object_permissions, migration_state)
        return partial(self._applier_task, item.object_type, item.object_id, new_acl)

    @staticmethod
    def _is_item_relevant(item: Permissions, migration_state: MigrationState) -> bool:
        """True when the stored ACL mentions at least one group that is in migration scope."""
        # passwords and tokens are represented on the workspace-level
        if item.object_id in {"tokens", "passwords"}:
            return True
        object_permissions = iam.ObjectPermissions.from_dict(json.loads(item.raw))
        assert object_permissions.access_control_list is not None
        for acl in object_permissions.access_control_list:
            if not acl.group_name:
                continue
            if migration_state.is_in_scope(acl.group_name):
                return True
        return False

    @staticmethod
    def _response_to_request(
        acls: list[iam.AccessControlResponse] | None = None,
    ) -> list[iam.AccessControlRequest]:
        """Flatten ACL responses into one AccessControlRequest per (principal, permission level)."""
        results: list[iam.AccessControlRequest] = []
        if not acls:
            return results
        for acl in acls:
            if not acl.all_permissions:
                continue
            for permission in acl.all_permissions:
                results.append(
                    iam.AccessControlRequest(
                        acl.group_name, permission.permission_level, acl.service_principal_name, acl.user_name
                    )
                )
        return results

    @rate_limited(max_requests=100)
    def _verify(self, object_type: str, object_id: str, acl: list[iam.AccessControlRequest]):
        """Check that every entry of ``acl`` is present on the remote object.

        Raises NotFound when the object's current ACL is missing any expected
        entry (so callers can retry), returns True on full match, and False
        when the object's permissions could not be read at all.
        """
        # in-flight check for the applied permissions
        # the api might be inconsistent, therefore we need to check that the permissions were applied
        remote_permission = self._safe_get_permissions(object_type, object_id)
        if remote_permission:
            remote_permission_as_request = self._response_to_request(remote_permission.access_control_list)
            if all(elem in remote_permission_as_request for elem in acl):
                return True
            msg = (
                f"Couldn't find permission for object type {object_type} with id {object_id}\n"
                f"acl to be applied={acl}\n"
                f"acl found in the object={remote_permission_as_request}\n"
            )
            raise NotFound(msg)
        return False

    def get_verify_task(self, item: Permissions) -> Callable[[], bool]:
        """Return a no-argument callable that verifies the stored ACL against the live object.

        Raises ValueError when the stored raw payload carries no ACL entries.
        """
        acl = iam.ObjectPermissions.from_dict(json.loads(item.raw))
        if not acl.access_control_list:
            raise ValueError(
                f"Access control list not present for object type " f"{item.object_type} and object id {item.object_id}"
            )
        permissions_as_request = self._response_to_request(acl.access_control_list)
        return partial(self._verify, item.object_type, item.object_id, permissions_as_request)

    @rate_limited(max_requests=30)
    def _applier_task(self, object_type: str, object_id: str, acl: list[iam.AccessControlRequest]):
        """Apply ``acl`` to the object and verify it stuck, retrying both steps on transient errors."""
        retryable_exceptions = [InternalError, NotFound, ResourceConflict, TemporarilyUnavailable, DeadlineExceeded]
        update_retry_on_value_error = retried(
            on=retryable_exceptions, timeout=self._verify_timeout  # type: ignore[arg-type]
        )
        update_retried_check = update_retry_on_value_error(self._safe_update_permissions)
        update_retried_check(object_type, object_id, acl)
        # _verify raises NotFound on mismatch, which the retry wrapper treats as retryable
        retry_on_value_error = retried(on=retryable_exceptions, timeout=self._verify_timeout)
        retried_check = retry_on_value_error(self._verify)
        return retried_check(object_type, object_id, acl)

    @rate_limited(max_requests=100)
    def _crawler_task(self, object_type: str, object_id: str) -> Permissions | None:
        """Fetch one object's ACL and serialize it into a Permissions row.

        Returns None when the object has no permissions, or when a job/pipeline
        lacks an IS_OWNER entry (such objects cannot be migrated to account-level
        groups).
        """
        objects_with_owner_permission = ["jobs", "pipelines"]
        permissions = self._safe_get_permissions(object_type, object_id)
        if not permissions:
            logger.warning(f"Object {object_type} {object_id} doesn't have any permissions")
            return None
        if not self._object_have_owner(permissions) and object_type in objects_with_owner_permission:
            logger.warning(
                f"Object {object_type} {object_id} doesn't have any Owner and cannot be migrated "
                f"to account level groups, consider setting a new owner or deleting this object"
            )
            return None
        return Permissions(
            object_id=object_id,
            object_type=object_type,
            raw=json.dumps(permissions.as_dict()),
        )

    def _object_have_owner(self, permissions: iam.ObjectPermissions | None):
        """True when any ACL entry carries the IS_OWNER permission level."""
        if not permissions:
            return False
        if not permissions.access_control_list:
            return False
        for acl in permissions.access_control_list:
            if not acl.all_permissions:
                continue
            for perm in acl.all_permissions:
                if perm.permission_level == PermissionLevel.IS_OWNER:
                    return True
        return False

    def _load_as_request(self, object_type: str, object_id: str) -> list[iam.AccessControlRequest]:
        """Read the object's live ACL and convert non-inherited entries into requests."""
        loaded = self._safe_get_permissions(object_type, object_id)
        if loaded is None:
            return []
        acl: list[iam.AccessControlRequest] = []
        if not loaded.access_control_list:
            return acl
        for access_control in loaded.access_control_list:
            if not access_control.all_permissions:
                continue
            for permission in access_control.all_permissions:
                if permission.inherited:
                    continue
                acl.append(
                    iam.AccessControlRequest(
                        permission_level=permission.permission_level,
                        service_principal_name=access_control.service_principal_name,
                        group_name=access_control.group_name,
                        user_name=access_control.user_name,
                    )
                )
        # sort to return deterministic results
        return sorted(acl, key=lambda v: f"{v.group_name}:{v.user_name}:{v.service_principal_name}")

    def load_as_dict(self, object_type: str, object_id: str) -> dict[str, iam.PermissionLevel]:
        """Map each principal name on the object to its (non-inherited) permission level.

        NOTE(review): when a principal holds several permission levels, later
        entries overwrite earlier ones — callers see only one level per principal.
        """
        result = {}
        for acl in self._load_as_request(object_type, object_id):
            if not acl.permission_level:
                continue
            result[self._key_for_acl_dict(acl)] = acl.permission_level
        return result

    @staticmethod
    def _key_for_acl_dict(acl: iam.AccessControlRequest) -> str:
        """Pick the principal name from an ACL request: group, then user, then service principal."""
        if acl.group_name is not None:
            return acl.group_name
        if acl.user_name is not None:
            return acl.user_name
        if acl.service_principal_name is not None:
            return acl.service_principal_name
        return "UNKNOWN"

    # TODO remove after ES-892977 is fixed
    @retried(on=[InternalError], timeout=timedelta(minutes=5))
    def _safe_get_permissions(self, object_type: str, object_id: str) -> iam.ObjectPermissions | None:
        """GET the object's permissions; log and return None on denied/missing/invalid objects."""
        try:
            return self._ws.permissions.get(object_type, object_id)
        except PermissionDenied:
            logger.warning(f"permission denied: {object_type} {object_id}")
            return None
        except NotFound:
            logger.warning(f"removed on backend: {object_type} {object_id}")
            return None
        except InvalidParameterValue:
            logger.warning(f"jobs or cluster removed on backend: {object_type} {object_id}")
            return None

    def _safe_update_permissions(
        self, object_type: str, object_id: str, acl: list[iam.AccessControlRequest]
    ) -> iam.ObjectPermissions | None:
        """Update the object's permissions; log and return None on denied/missing/invalid objects."""
        try:
            return self._ws.permissions.update(object_type, object_id, access_control_list=acl)
        except PermissionDenied:
            logger.warning(f"permission denied: {object_type} {object_id}")
            return None
        except NotFound:
            logger.warning(f"removed on backend: {object_type} {object_id}")
            return None
        except InvalidParameterValue:
            logger.warning(f"jobs or cluster removed on backend: {object_type} {object_id}")
            return None

    def _prepare_new_acl(
        self, permissions: iam.ObjectPermissions, migration_state: MigrationState
    ) -> list[iam.AccessControlRequest]:
        """Build ACL requests with in-scope workspace group names replaced by target principals.

        Entries without a group name, out-of-scope groups, groups without a
        target principal, and inherited permissions are all skipped.
        """
        _acl = permissions.access_control_list
        acl_requests: list[iam.AccessControlRequest] = []
        if not _acl:
            return acl_requests
        coord = f"{permissions.object_type}/{permissions.object_id}"
        for _item in _acl:
            if not _item.group_name:
                continue
            if not migration_state.is_in_scope(_item.group_name):
                logger.debug(f"Skipping {_item} for {coord} because it is not in scope")
                continue
            new_group_name = migration_state.get_target_principal(_item.group_name)
            if new_group_name is None:
                logger.debug(f"Skipping {_item.group_name} for {coord} because it has no target principal")
                continue
            if not _item.all_permissions:
                continue
            for permission in _item.all_permissions:
                if permission.inherited:
                    continue
                acl_requests.append(
                    iam.AccessControlRequest(
                        group_name=new_group_name,
                        service_principal_name=_item.service_principal_name,
                        user_name=_item.user_name,
                        permission_level=permission.permission_level,
                    )
                )
        return acl_requests

    def __repr__(self):
        return f"GenericPermissionsSupport({self._listings})"
class WorkspaceListing(Listing, CrawlerBase[WorkspaceObjectInfo]):
    """Crawls the workspace file tree and serves it as a permissions listing.

    Combines ``Listing`` (so it can be plugged into GenericPermissionsSupport)
    with ``CrawlerBase`` (so the walked tree is cached in the inventory table
    ``workspace_objects``). Iteration reads from the snapshot, not the live
    workspace.
    """

    def __init__(
        self,
        ws: WorkspaceClient,
        sql_backend: SqlBackend,
        inventory_database: str,
        num_threads=20,
        start_path: str | None = "/",
    ):
        # Listing's func/id_attribute/object_type are unused here: __iter__ is overridden below
        Listing.__init__(self, lambda: [], "_", "_")
        CrawlerBase.__init__(
            self,
            sql_backend=sql_backend,
            catalog="hive_metastore",
            schema=inventory_database,
            table="workspace_objects",
            klass=WorkspaceObjectInfo,
        )
        self._ws = ws
        self._num_threads = num_threads
        self._start_path = start_path
        self._sql_backend = sql_backend
        self._inventory_database = inventory_database

    def _crawl(self) -> Iterable[WorkspaceObjectInfo]:
        """Walk the workspace tree (excluding directories) and yield one row per object."""
        # pylint: disable-next=import-outside-toplevel,redefined-outer-name
        from databricks.labs.ucx.workspace_access.listing import WorkspaceListing

        ws_listing = WorkspaceListing(self._ws, num_threads=self._num_threads, with_directories=False)
        for obj in ws_listing.walk(self._start_path):
            if obj is None or obj.object_type is None:
                continue
            raw = obj.as_dict()
            yield WorkspaceObjectInfo(
                object_type=raw.get("object_type", None),
                object_id=str(raw.get("object_id", None)),
                path=raw.get("path", None),
                language=raw.get("language", None),
            )

    def _try_fetch(self) -> Iterable[WorkspaceObjectInfo]:
        """Load previously-crawled rows back from the inventory table."""
        for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
            yield WorkspaceObjectInfo(
                path=row["path"], object_type=row["object_type"], object_id=row["object_id"], language=row["language"]
            )

    def object_types(self) -> set[str]:
        """Return the permission-API object types this listing can produce."""
        return {"notebooks", "directories", "repos", "files"}

    @staticmethod
    def _convert_object_type_to_request_type(_object: WorkspaceObjectInfo) -> str | None:
        """Map a workspace object type to its permissions-API request type; None means 'skip'."""
        match _object.object_type:
            case "NOTEBOOK":
                return "notebooks"
            case "DIRECTORY":
                return "directories"
            case "LIBRARY":
                # libraries have no separately-managed permissions
                return None
            case "REPO":
                return "repos"
            case "FILE":
                return "files"
        # silent handler for experiments - they'll be inventoried by the experiments manager
        return None

    def __iter__(self):
        """Yield GenericPermissionsInfo for every snapshot object with a mappable type."""
        for _object in self.snapshot():
            request_type = self._convert_object_type_to_request_type(_object)
            if not request_type:
                continue
            assert _object.object_id is not None
            yield GenericPermissionsInfo(str(_object.object_id), request_type)

    def __repr__(self):
        return f"WorkspaceListing(start_path={self._start_path})"
def models_listing(ws: WorkspaceClient, num_threads: int | None) -> Callable[[], Iterator[ml.ModelDatabricks]]:
    """Build a callable that yields full details for every registered model.

    The list API returns only summaries, so each model is fetched individually
    (in parallel) to obtain its Databricks-specific payload; any fetch failure
    aborts the whole listing with a ManyError.
    """

    def inner() -> Iterator[ml.ModelDatabricks]:
        fetchers = [partial(ws.model_registry.get_model, name=m.name) for m in ws.model_registry.list_models()]
        models, errors = Threads.gather("listing model ids", fetchers, num_threads)
        if errors:
            raise ManyError(errors)
        for response in models:
            details = response.registered_model_databricks
            if details:
                yield details

    return inner
def experiments_listing(ws: WorkspaceClient):
    """Build a callable that lists MLflow experiments, skipping notebook-backed ones."""

    def _repo_notebook_tags(experiment):
        # repo-based notebook experiments carry a sourceType=REPO_NOTEBOOK tag
        return [t for t in experiment.tags if t.key == "mlflow.experiment.sourceType" and t.value == "REPO_NOTEBOOK"]

    def inner() -> Iterator[ml.Experiment]:
        for experiment in ws.experiments.list_experiments():
            # We filter-out notebook-based experiments, because they are covered by notebooks listing in
            # workspace-based notebook experiment
            if experiment.tags:
                is_notebook = any(
                    t.key == "mlflow.experimentType" and t.value == "NOTEBOOK" for t in experiment.tags
                )
                if is_notebook or _repo_notebook_tags(experiment):
                    continue
            yield experiment

    return inner
def feature_store_listing(ws: WorkspaceClient):
    """Build a callable that pages through all feature tables via the raw REST endpoint.

    There is no typed SDK surface for the feature store search API, so paging is
    done manually with `next_page_token`.
    """

    def inner() -> list[GenericPermissionsInfo]:
        collected: list[GenericPermissionsInfo] = []
        page_token = None
        while True:
            response = ws.api_client.do(
                "GET", "/api/2.0/feature-store/feature-tables/search", query={"page_token": page_token, "max_results": 200}
            )
            assert isinstance(response, dict)
            collected.extend(
                GenericPermissionsInfo(entry["id"], "feature-tables") for entry in response.get("feature_tables", [])
            )
            if "next_page_token" not in response:
                break
            page_token = response["next_page_token"]
        return collected

    return inner
def feature_tables_root_page():
    """Return the synthetic '/root' object for crawling feature-table root permissions."""
    root = GenericPermissionsInfo("/root", "feature-tables")
    return [root]
def models_root_page():
    """Return the synthetic '/root' object for crawling registered-model root permissions."""
    root = GenericPermissionsInfo("/root", "registered-models")
    return [root]
def tokens_and_passwords():
    """Yield the two workspace-level authorization objects, tokens first."""
    yield GenericPermissionsInfo("tokens", "authorization")
    yield GenericPermissionsInfo("passwords", "authorization")