Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 23 additions & 32 deletions airflow/api_connexion/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Any

import werkzeug
from connexion import FlaskApi, ProblemException, problem
from connexion import ProblemException, problem

from airflow.utils.docs import get_docs_url

if TYPE_CHECKING:
import flask
from connexion.lifecycle import ConnexionRequest, ConnexionResponse

doc_link = get_docs_url("stable-rest-api-ref.html")

Expand All @@ -40,37 +39,29 @@
}


def common_error_handler(exception: BaseException) -> flask.Response:
def problem_error_handler(_request: ConnexionRequest, exception: ProblemException) -> ConnexionResponse:
"""Use to capture connexion exceptions and add link to the type field."""
if isinstance(exception, ProblemException):
link = EXCEPTIONS_LINK_MAP.get(exception.status)
if link:
response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=link,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
else:
response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=exception.type,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
link = EXCEPTIONS_LINK_MAP.get(exception.status)
if link:
return problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=link,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
else:
if not isinstance(exception, werkzeug.exceptions.HTTPException):
exception = werkzeug.exceptions.InternalServerError()

response = problem(title=exception.name, detail=exception.description, status=exception.code)

return FlaskApi.get_response(response)
return problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=exception.type,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)


class NotFound(ProblemException):
Expand Down
3 changes: 2 additions & 1 deletion airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
import connexion
from flask import Blueprint
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -79,7 +80,7 @@ def get_cli_commands() -> list[CLICommand]:
"""
return []

def get_api_endpoints(self) -> None | Blueprint:
def get_api_endpoints(self, connexion_app: connexion.FlaskApp) -> None | Blueprint:
"""Return API endpoint(s) definition for the auth manager."""
return None

Expand Down
17 changes: 9 additions & 8 deletions airflow/cli/commands/internal_api_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from tempfile import gettempdir
from time import sleep

import connexion
import psutil
from flask import Flask
from flask_appbuilder import SQLA
from flask_caching import Cache
from flask_wtf.csrf import CSRFProtect
Expand All @@ -54,7 +54,7 @@
from airflow.www.extensions.init_views import init_api_internal, init_error_handlers

log = logging.getLogger(__name__)
app: Flask | None = None
app: connexion.FlaskApp | None = None


@cli_utils.action_cli
Expand All @@ -73,8 +73,8 @@ def internal_api(args):
log.info(f"Starting the Internal API server on port {args.port} and host {args.hostname}.")
app = create_app(testing=conf.getboolean("core", "unit_test_mode"))
app.run(
debug=True, # nosec
use_reloader=not app.config["TESTING"],
log_level="debug",
# reload=not app.app.config["TESTING"],
port=args.port,
host=args.hostname,
)
Expand All @@ -101,7 +101,7 @@ def internal_api(args):
"--workers",
str(num_workers),
"--worker-class",
str(args.workerclass),
"uvicorn.workers.UvicornWorker",
"--timeout",
str(worker_timeout),
"--bind",
Expand Down Expand Up @@ -195,7 +195,8 @@ def start_and_monitor_gunicorn(args):

def create_app(config=None, testing=False):
"""Create a new instance of Airflow Internal API app."""
flask_app = Flask(__name__)
connexion_app = connexion.FlaskApp(__name__)
flask_app = connexion_app.app

flask_app.config["APP_NAME"] = "Airflow Internal API"
flask_app.config["TESTING"] = testing
Expand Down Expand Up @@ -240,11 +241,11 @@ def create_app(config=None, testing=False):

with flask_app.app_context():
init_error_handlers(flask_app)
init_api_internal(flask_app, standalone_api=True)
init_api_internal(connexion_app, standalone_api=True)

init_jinja_globals(flask_app)
init_xframe_protection(flask_app)
return flask_app
return connexion_app


def cached_app(config=None, testing=False):
Expand Down
8 changes: 4 additions & 4 deletions airflow/cli/commands/webserver_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,11 @@ def webserver(args):
print(f"Starting the web server on port {args.port} and host {args.hostname}.")
app = create_app(testing=conf.getboolean("core", "unit_test_mode"))
app.run(
debug=True,
use_reloader=not app.config["TESTING"],
log_level="debug",
port=args.port,
host=args.hostname,
ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None,
ssl_keyfile=ssl_key if ssl_cert and ssl_key else None,
ssl_certfile=ssl_cert if ssl_cert and ssl_key else None,
)
else:
print(
Expand All @@ -383,7 +383,7 @@ def webserver(args):
"--workers",
str(num_workers),
"--worker-class",
str(args.workerclass),
"uvicorn.workers.UvicornWorker",
"--timeout",
str(worker_timeout),
"--bind",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@

def remap_permissions():
"""Apply Map Airflow permissions."""
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder
for old, new in mapping.items():
(old_resource_name, old_action_name) = old
old_permission = appbuilder.sm.get_permission(old_action_name, old_resource_name)
Expand All @@ -313,7 +313,7 @@ def remap_permissions():

def undo_remap_permissions():
"""Unapply Map Airflow permissions"""
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder
for old, new in mapping.items():
(new_resource_name, new_action_name) = new[0]
new_permission = appbuilder.sm.get_permission(new_action_name, new_resource_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def upgrade():
log = logging.getLogger()
handlers = log.handlers[:]

appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder
roles_to_modify = [role for role in appbuilder.sm.get_all_roles() if role.name in ["User", "Viewer"]]
can_read_on_config_perm = appbuilder.sm.get_permission(
permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG
Expand All @@ -59,7 +59,7 @@ def upgrade():

def downgrade():
"""Add can_read action on config resource for User and Viewer role"""
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder
roles_to_modify = [role for role in appbuilder.sm.get_all_roles() if role.name in ["User", "Viewer"]]
can_read_on_config_perm = appbuilder.sm.get_permission(
permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@

def remap_permissions():
"""Apply Map Airflow permissions."""
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder
for old, new in mapping.items():
(old_resource_name, old_action_name) = old
old_permission = appbuilder.sm.get_permission(old_action_name, old_resource_name)
Expand All @@ -165,7 +165,7 @@ def remap_permissions():

def undo_remap_permissions():
"""Unapply Map Airflow permissions"""
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder
for old, new in mapping.items():
(new_resource_name, new_action_name) = new[0]
new_permission = appbuilder.sm.get_permission(new_action_name, new_resource_name)
Expand Down
23 changes: 14 additions & 9 deletions airflow/providers/fab/auth_manager/fab_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Container

from connexion import FlaskApi
from connexion.options import SwaggerUIOptions
from flask import Blueprint, url_for
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
Expand Down Expand Up @@ -83,9 +83,11 @@
)
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.yaml import safe_load
from airflow.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver
from airflow.www.extensions.init_views import _LazyResolver

if TYPE_CHECKING:
import connexion

from airflow.auth.managers.models.base_user import BaseUser
from airflow.cli.cli_config import (
CLICommand,
Expand Down Expand Up @@ -147,21 +149,24 @@ def get_cli_commands() -> list[CLICommand]:
SYNC_PERM_COMMAND, # not in a command group
]

def get_api_endpoints(self) -> None | Blueprint:
def get_api_endpoints(self, connexion_app: connexion.FlaskApp) -> None | Blueprint:
folder = Path(__file__).parents[0].resolve() # this is airflow/auth/managers/fab/
with folder.joinpath("openapi", "v1.yaml").open() as f:
specification = safe_load(f)
return FlaskApi(

swagger_ui_options = SwaggerUIOptions(
swagger_ui=conf.getboolean("webserver", "enable_swagger_ui", fallback=True),
)

api = connexion_app.add_api(
specification=specification,
resolver=_LazyResolver(),
base_path="/auth/fab/v1",
options={
"swagger_ui": conf.getboolean("webserver", "enable_swagger_ui", fallback=True),
},
swagger_ui_options=swagger_ui_options,
strict_validation=True,
validate_responses=True,
validator_map={"body": _CustomErrorRequestBodyValidator},
).blueprint
)
return api.blueprint if api else None

def get_user_display_name(self) -> str:
"""Return the user's display name associated to the user in session."""
Expand Down
3 changes: 2 additions & 1 deletion airflow/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class AirflowJsonProvider(JSONProvider):
def dumps(self, obj, **kwargs):
kwargs.setdefault("ensure_ascii", self.ensure_ascii)
kwargs.setdefault("sort_keys", self.sort_keys)
return json.dumps(obj, **kwargs, cls=WebEncoder)
kwargs.setdefault("cls", WebEncoder)
return json.dumps(obj, **kwargs)

def loads(self, s: str | bytes, **kwargs):
return json.loads(s, **kwargs)
Expand Down
30 changes: 22 additions & 8 deletions airflow/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
import warnings
from datetime import timedelta

from flask import Flask
import connexion
from flask_appbuilder import SQLA
from flask_wtf.csrf import CSRFProtect
from markupsafe import Markup
from sqlalchemy.engine.url import make_url
from starlette.middleware.cors import CORSMiddleware

from airflow import settings
from airflow.api_internal.internal_api_call import InternalApiConfig
Expand Down Expand Up @@ -61,7 +62,7 @@
)
from airflow.www.extensions.init_wsgi_middlewares import init_wsgi_middleware

app: Flask | None = None
app: connexion.FlaskApp | None = None

# Initializes at the module level, so plugins can access it.
# See: /docs/plugins.rst
Expand All @@ -70,7 +71,18 @@

def create_app(config=None, testing=False):
"""Create a new instance of Airflow WWW app."""
flask_app = Flask(__name__)
connexion_app = connexion.FlaskApp(__name__)

connexion_app.add_middleware(
CORSMiddleware,
connexion.middleware.MiddlewarePosition.BEFORE_ROUTING,
allow_origins=conf.get("api", "access_control_allow_origins"),
allow_credentials=True,
allow_methods=conf.get("api", "access_control_allow_methods"),
allow_headers=conf.get("api", "access_control_allow_headers"),
)

flask_app = connexion_app.app
flask_app.secret_key = conf.get("webserver", "SECRET_KEY")

flask_app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(minutes=settings.get_session_lifetime_config())
Expand Down Expand Up @@ -158,22 +170,24 @@ def create_app(config=None, testing=False):
init_appbuilder_links(flask_app)
init_plugins(flask_app)
init_error_handlers(flask_app)
init_api_connexion(flask_app)
init_api_connexion(connexion_app)
if conf.getboolean("webserver", "run_internal_api", fallback=False):
if not _ENABLE_AIP_44:
raise RuntimeError("The AIP_44 is not enabled so you cannot use it.")
init_api_internal(flask_app)
init_api_internal(connexion_app)
init_api_experimental(flask_app)
init_api_auth_provider(flask_app)
init_api_error_handlers(flask_app) # needs to be after all api inits to let them add their path first
init_api_auth_provider(connexion_app)
init_api_error_handlers(
connexion_app
) # needs to be after all api inits to let them add their path first

get_auth_manager().init()

init_jinja_globals(flask_app)
init_xframe_protection(flask_app)
init_airflow_session_interface(flask_app)
init_check_user_active(flask_app)
return flask_app
return connexion_app


def cached_app(config=None, testing=False):
Expand Down
Loading