Skip to content

Commit 4fbb211

Browse files
feat: add headless auth (#508)
1 parent 1958c25 commit 4fbb211

9 files changed

Lines changed: 147 additions & 88 deletions

File tree

safety/auth/cli.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def render_successful_login(auth: Auth,
9696

9797

9898
@auth_app.command(name=CMD_LOGIN_NAME, help=CLI_AUTH_LOGIN_HELP)
99-
def login(ctx: typer.Context):
99+
def login(ctx: typer.Context, headless: bool = False):
100100
"""
101101
Authenticate Safety CLI with your safetycli.com account using your default browser.
102102
"""
@@ -105,29 +105,38 @@ def login(ctx: typer.Context):
105105
fail_if_authenticated(ctx, with_msg=MSG_FAIL_LOGIN_AUTHED)
106106

107107
console.print()
108-
brief_msg: str = "Redirecting your browser to log in; once authenticated, " \
109-
"return here to start using Safety"
110108

111-
uri, initial_state = get_authorization_data(client=ctx.obj.auth.client,
112-
code_verifier=ctx.obj.auth.code_verifier,
113-
organization=ctx.obj.auth.org)
109+
info = None
114110

115-
if ctx.obj.auth.org:
111+
brief_msg: str = "Redirecting your browser to log in; once authenticated, " \
112+
"return here to start using Safety"
113+
114+
if ctx.obj.auth.org:
116115
console.print(f"Logging into [bold]{ctx.obj.auth.org.name}[/bold] " \
117116
"organization.")
118-
117+
118+
if headless:
119+
brief_msg = "Running in headless mode. Please copy and open the following URL in a browser"
120+
121+
122+
uri, initial_state = get_authorization_data(client=ctx.obj.auth.client,
123+
code_verifier=ctx.obj.auth.code_verifier,
124+
organization=ctx.obj.auth.org, headless=headless)
119125
click.secho(brief_msg)
120126
click.echo()
121127

122-
info = process_browser_callback(uri,
123-
initial_state=initial_state, ctx=ctx)
128+
info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx, headless=headless)
129+
124130

125131
if info:
126132
if info.get("email", None):
127133
organization = None
128134
if ctx.obj.auth.org and ctx.obj.auth.org.name:
129135
organization = ctx.obj.auth.org.name
130136
ctx.obj.auth.refresh_from(info)
137+
if headless:
138+
console.print()
139+
131140
render_successful_login(ctx.obj.auth, organization=organization)
132141

133142
console.print()
@@ -149,7 +158,7 @@ def login(ctx: typer.Context):
149158
else:
150159
msg += "Error logging into Safety."
151160

152-
msg += " Please try again, or use [bold]`safety auth –help`[/bold] " \
161+
msg += " Please try again, or use [bold]`safety auth -–help`[/bold] " \
153162
"for more information[/red]"
154163

155164
console.print(msg, emoji=True)

safety/auth/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33

44
from typing import Any, Dict, Optional, Tuple, Union
5+
from urllib.parse import urlencode
56

67
from authlib.oidc.core import CodeIDToken
78
from authlib.jose import jwt
@@ -17,9 +18,9 @@
1718

1819
def get_authorization_data(client, code_verifier: str,
1920
organization: Optional[Organization] = None,
20-
sign_up: bool = False, ensure_auth: bool = False) -> Tuple[str, str]:
21+
sign_up: bool = False, ensure_auth: bool = False, headless: bool = False) -> Tuple[str, str]:
2122

22-
kwargs = {'sign_up': sign_up, 'locale': 'en', 'ensure_auth': ensure_auth}
23+
kwargs = {'sign_up': sign_up, 'locale': 'en', 'ensure_auth': ensure_auth, 'headless': headless}
2324
if organization:
2425
kwargs['organization'] = organization.id
2526

safety/auth/server.py

Lines changed: 85 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import http.server
2+
import json
23
import logging
34
import socket
45
import sys
@@ -13,6 +14,8 @@
1314

1415
from safety.auth.constants import AUTH_SERVER_URL, CLI_AUTH_SUCCESS, CLI_LOGOUT_SUCCESS, HOST
1516
from safety.auth.main import save_auth_config
17+
from authlib.integrations.base_client.errors import OAuthError
18+
from rich.prompt import Prompt
1619

1720
LOG = logging.getLogger(__name__)
1821

@@ -33,40 +36,49 @@ def find_available_port():
3336

3437
return None
3538

39+
def auth_process(code: str, state: str, initial_state: str, code_verifier, client):
40+
err = None
41+
42+
if initial_state is None or initial_state != state:
43+
err = "The state parameter value provided does not match the expected " \
44+
"value. The state parameter is used to protect against Cross-Site " \
45+
"Request Forgery (CSRF) attacks. For security reasons, the " \
46+
"authorization process cannot proceed with an invalid state " \
47+
"parameter value. Please try again, ensuring that the state " \
48+
"parameter value provided in the authorization request matches " \
49+
"the value returned in the callback."
50+
51+
if err:
52+
click.secho(f'Error: {err}', fg='red')
53+
sys.exit(1)
54+
55+
try:
56+
tokens = client.fetch_token(url=f'{AUTH_SERVER_URL}/oauth/token',
57+
code_verifier=code_verifier,
58+
client_id=client.client_id,
59+
grant_type='authorization_code', code=code)
60+
61+
save_auth_config(access_token=tokens['access_token'],
62+
id_token=tokens['id_token'],
63+
refresh_token=tokens['refresh_token'])
64+
return client.fetch_user_info()
65+
66+
except Exception as e:
67+
LOG.exception(e)
68+
sys.exit(1)
3669

3770
class CallbackHandler(http.server.BaseHTTPRequestHandler):
3871
def auth(self, code: str, state: str, err, error_description):
3972
initial_state = self.server.initial_state
4073
ctx = self.server.ctx
4174

42-
if initial_state is None or initial_state != state:
43-
err = "The state parameter value provided does not match the expected" \
44-
"value. The state parameter is used to protect against Cross-Site " \
45-
"Request Forgery (CSRF) attacks. For security reasons, the " \
46-
"authorization process cannot proceed with an invalid state " \
47-
"parameter value. Please try again, ensuring that the state " \
48-
"parameter value provided in the authorization request matches " \
49-
"the value returned in the callback."
50-
51-
if err:
52-
click.secho(f'Error: {err}', fg='red')
53-
sys.exit(1)
75+
result = auth_process(code=code,
76+
state=state,
77+
initial_state=initial_state,
78+
code_verifier=ctx.obj.auth.code_verifier,
79+
client=ctx.obj.auth.client)
5480

55-
try:
56-
tokens = ctx.obj.auth.client.fetch_token(url=f'{AUTH_SERVER_URL}/oauth/token',
57-
code_verifier=ctx.obj.auth.code_verifier,
58-
client_id=ctx.obj.auth.client.client_id,
59-
grant_type='authorization_code', code=code)
60-
61-
save_auth_config(access_token=tokens['access_token'],
62-
id_token=tokens['id_token'],
63-
refresh_token=tokens['refresh_token'])
64-
self.server.callback = ctx.obj.auth.client.fetch_user_info()
65-
66-
except Exception as e:
67-
LOG.exception(e)
68-
sys.exit(1)
69-
81+
self.server.callback = result
7082
self.do_redirect(location=CLI_AUTH_SUCCESS, params={})
7183

7284
def logout(self):
@@ -132,27 +144,52 @@ def handle_timeout(self) -> None:
132144
sys.exit(1)
133145

134146
try:
135-
server = ThreadedHTTPServer((HOST, PORT), CallbackHandler)
136-
server.initial_state = kwargs.get("initial_state", None)
137-
server.timeout = kwargs.get("timeout", 600)
138-
# timeout = kwargs.get("timeout", None)
139-
# timeout = float(timeout) if timeout else None
140-
server.ctx = kwargs.get("ctx", None)
141-
server_thread = threading.Thread(target=server.handle_request)
142-
server_thread.start()
143-
144-
target = f"{uri}&port={PORT}"
145-
console.print(f"If the browser does not automatically open in 5 seconds, " \
146-
"copy and paste this url into your browser: " \
147-
f"[link={target}]{target}[/link]")
148-
click.echo()
149-
150-
wait_msg = "waiting for browser authentication"
151-
152-
with console.status(wait_msg, spinner="bouncingBar"):
153-
time.sleep(2)
154-
click.launch(target)
155-
server_thread.join()
147+
headless = kwargs.get("headless", False)
148+
initial_state = kwargs.get("initial_state", None)
149+
ctx = kwargs.get("ctx", None)
150+
151+
message = "Copy and paste this url into your browser:"
152+
153+
154+
if not headless:
155+
server = ThreadedHTTPServer((HOST, PORT), CallbackHandler)
156+
server.initial_state = initial_state
157+
server.timeout = kwargs.get("timeout", 600)
158+
server.ctx = ctx
159+
server_thread = threading.Thread(target=server.handle_request)
160+
server_thread.start()
161+
message = f"If the browser does not automatically open in 5 seconds, " \
162+
"copy and paste this url into your browser:"
163+
164+
target = uri if headless else f"{uri}&port={PORT}"
165+
console.print(f"{message} [link={target}]{target}[/link]")
166+
console.print()
167+
168+
if headless:
169+
170+
exchange_data = None
171+
while not exchange_data:
172+
auth_code_text = Prompt.ask("Paste the response here", default=None, console=console)
173+
try:
174+
exchange_data = json.loads(auth_code_text)
175+
state = exchange_data["state"]
176+
code = exchange_data["code"]
177+
except Exception as e:
178+
code = state = None
179+
180+
return auth_process(code=code,
181+
state=state,
182+
initial_state=initial_state,
183+
code_verifier=ctx.obj.auth.code_verifier,
184+
client=ctx.obj.auth.client)
185+
else:
186+
187+
wait_msg = "waiting for browser authentication"
188+
189+
with console.status(wait_msg, spinner="bouncingBar"):
190+
time.sleep(2)
191+
click.launch(target)
192+
server_thread.join()
156193

157194
except OSError as e:
158195
if e.errno == socket.errno.EADDRINUSE:

test_requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
pytest
2-
pytest-cov
1+
pytest==7.4.4
2+
pytest-cov==4.1.0
33
setuptools>=65.5.1; python_version>="3.7"
44
setuptools; python_version=="3.6"
55
Click>=8.0.2

tests/auth/test_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_auth_calls_login(self, process_browser_callback,
2828
get_authorization_data.assert_called_once()
2929
process_browser_callback.assert_called_once_with(auth_data[0],
3030
initial_state=auth_data[1],
31-
ctx=ANY)
31+
ctx=ANY, headless=False)
3232

3333
expected = [
3434
"",

tests/auth/test_main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def test_get_authorization_data(self):
3030
"sign_up": False,
3131
"locale": "en",
3232
"ensure_auth": False,
33-
"organization": org_id
33+
"organization": org_id,
34+
"headless": False
3435
}
3536

3637
client.create_authorization_url.assert_called_once_with(
@@ -42,7 +43,8 @@ def test_get_authorization_data(self):
4243
kwargs = {
4344
"sign_up": False,
4445
"locale": "en",
45-
"ensure_auth":False
46+
"ensure_auth":False,
47+
"headless": False
4648
}
4749

4850
client.create_authorization_url.assert_called_once_with(

tests/test_cli.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,7 @@ def test_validate_with_basic_policy_file(self):
204204
result = self.runner.invoke(cli.cli, ['validate', 'policy_file', '3.0', '--path', path])
205205
cleaned_stdout = click.unstyle(result.stdout)
206206
msg = 'The Safety policy (3.0) file (Used for scan and system-scan commands) was successfully parsed with the following values:\n'
207-
parsed = json.dumps(
208-
{
207+
parsed = {
209208
"version": "3.0",
210209
"scan": {
211210
"max_depth": 6,
@@ -230,19 +229,19 @@ def test_validate_with_basic_policy_file(self):
230229
},
231230
"fail_scan": {
232231
"dependency_vulnerabilities": {
233-
"enabled": True,
234-
"fail_on_any_of": {
235-
"cvss_severity": [
236-
"critical",
237-
"high",
238-
"medium"
239-
],
240-
"exploitability": [
241-
"critical",
242-
"high",
243-
"medium"
244-
]
245-
}
232+
"enabled": True,
233+
"fail_on_any_of": {
234+
"cvss_severity": [
235+
"critical",
236+
"high",
237+
"medium",
238+
],
239+
"exploitability": [
240+
"critical",
241+
"high",
242+
"medium",
243+
]
244+
}
246245
}
247246
},
248247
"security_updates": {
@@ -252,12 +251,21 @@ def test_validate_with_basic_policy_file(self):
252251
]
253252
}
254253
}
255-
},
256-
indent=2
257-
) + '\n'
254+
}
258255

259-
self.assertEqual(msg + parsed, cleaned_stdout)
260-
self.assertEqual(result.exit_code, 0)
256+
msg_stdout, parsed_policy = cleaned_stdout.split('\n', 1)
257+
msg_stdout += '\n'
258+
parsed_policy = json.loads(parsed_policy.replace('\n', ''))
259+
260+
fail_scan = parsed_policy.get("fail_scan", None)
261+
self.assertIsNotNone(fail_scan)
262+
fail_of_any = fail_scan["dependency_vulnerabilities"]["fail_on_any_of"]
263+
fail_of_any["cvss_severity"] = sorted(fail_of_any["cvss_severity"])
264+
fail_of_any["exploitability"] = sorted(fail_of_any["exploitability"])
265+
266+
self.assertEqual(msg, msg_stdout)
267+
self.assertEqual(parsed, parsed_policy)
268+
self.assertEqual(result.exit_code, 0)
261269

262270

263271
def test_validate_with_policy_file_using_invalid_keyword(self):

tests/test_safety.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,8 @@ def test_get_announcements_http_ok(self, get_used_options):
494494
@patch("safety.util.get_used_options")
495495
@patch.object(click, 'get_current_context', Mock(command=Mock(name=Mock(return_value='check'))))
496496
def test_get_announcements_wrong_json_response_handling(self, get_used_options):
497+
get_used_options.return_value = {}
498+
497499
# wrong JSON structure
498500
announcements = {
499501
"type": "notice",

tox.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ isolated_build = true
55

66
[testenv]
77
deps =
8-
pytest-cov
9-
pytest
8+
pytest-cov==4.1.0
9+
pytest==7.4.4
1010

1111
commands =
1212
pytest -rP tests/ --cov=safety/ --cov-report=html

0 commit comments

Comments
 (0)