Commit b2f0a16

fix(cache): handle get_from_cache=None and ensure directory exists (#544)
Signed-off-by: Dylan Pulver <[email protected]>
1 parent f15d790 · commit b2f0a16

5 files changed: 122 additions & 97 deletions

File tree

safety/safety.py (67 additions & 56 deletions)
@@ -12,7 +12,7 @@
 import time
 from collections import defaultdict
 from datetime import datetime
-from typing import Dict, Optional, List
+from typing import Dict, Optional, List, Any

 import click
 import requests
@@ -21,6 +21,7 @@
 from packaging.utils import canonicalize_name
 from packaging.version import parse as parse_version, Version
 from pydantic.json import pydantic_encoder
+from filelock import FileLock


 from safety_schemas.models import Ecosystem, FileType
@@ -41,34 +42,38 @@
 LOG = logging.getLogger(__name__)


-def get_from_cache(db_name, cache_valid_seconds=0, skip_time_verification=False):
-    if os.path.exists(DB_CACHE_FILE):
-        with open(DB_CACHE_FILE) as f:
-            try:
-                data = json.loads(f.read())
-                if db_name in data:
+def get_from_cache(db_name: str, cache_valid_seconds: int = 0, skip_time_verification: bool = False) -> Optional[Dict[str, Any]]:
+    cache_file_lock = f"{DB_CACHE_FILE}.lock"
+    os.makedirs(os.path.dirname(cache_file_lock), exist_ok=True)
+    lock = FileLock(cache_file_lock, timeout=10)
+    with lock:
+        if os.path.exists(DB_CACHE_FILE):
+            with open(DB_CACHE_FILE) as f:
+                try:
+                    data = json.loads(f.read())
+                    if db_name in data:

-                    if "cached_at" in data[db_name]:
-                        if data[db_name]["cached_at"] + cache_valid_seconds > time.time() or skip_time_verification:
-                            LOG.debug('Getting the database from cache at %s, cache setting: %s',
-                                      data[db_name]["cached_at"], cache_valid_seconds)
-
-                            try:
-                                data[db_name]["db"]["meta"]["base_domain"] = "https://data.safetycli.com"
-                            except KeyError as e:
-                                pass
+                        if "cached_at" in data[db_name]:
+                            if data[db_name]["cached_at"] + cache_valid_seconds > time.time() or skip_time_verification:
+                                LOG.debug('Getting the database from cache at %s, cache setting: %s',
+                                          data[db_name]["cached_at"], cache_valid_seconds)

-                            return data[db_name]["db"]
+                                try:
+                                    data[db_name]["db"]["meta"]["base_domain"] = "https://data.safetycli.com"
+                                except KeyError as e:
+                                    pass

-                        LOG.debug('Cached file is too old, it was cached at %s', data[db_name]["cached_at"])
-                    else:
-                        LOG.debug('There is not the cached_at key in %s database', data[db_name])
+                                return data[db_name]["db"]

-            except json.JSONDecodeError:
-                LOG.debug('JSONDecodeError trying to get the cached database.')
-    else:
-        LOG.debug("Cache file doesn't exist...")
-    return False
+                            LOG.debug('Cached file is too old, it was cached at %s', data[db_name]["cached_at"])
+                        else:
+                            LOG.debug('There is not the cached_at key in %s database', data[db_name])
+
+                except json.JSONDecodeError:
+                    LOG.debug('JSONDecodeError trying to get the cached database.')
+        else:
+            LOG.debug("Cache file doesn't exist...")
+    return None


 def write_to_cache(db_name, data):
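
Two behavioural fixes land in this hunk, matching the commit title: the parent directory of the lock file is created up front, and a cache miss now returns None (as the Optional[Dict[str, Any]] annotation advertises) instead of False. A small sketch of the directory-handling part, using an illustrative cache path rather than safety's real DB_CACHE_FILE:

import os

CACHE_FILE = os.path.expanduser("~/.cache/example/db.json")  # illustrative path
lock_path = f"{CACHE_FILE}.lock"

# exist_ok=True makes this idempotent: a pre-existing directory is not an error,
# and on a fresh machine FileLock can now create its .lock file instead of
# failing because the parent directory does not exist yet.
os.makedirs(os.path.dirname(lock_path), exist_ok=True)

Callers that tested the old return value with a bare `if cached_data:` keep working, since None and False are both falsy; the change mainly allows an explicit `is None` check and a more honest type annotation.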
@@ -95,25 +100,31 @@ def write_to_cache(db_name, data):
             if exc.errno != errno.EEXIST:
                 raise

-    with open(DB_CACHE_FILE, "r") as f:
-        try:
-            cache = json.loads(f.read())
-        except json.JSONDecodeError:
-            LOG.debug('JSONDecodeError in the local cache, dumping the full cache file.')
+    cache_file_lock = f"{DB_CACHE_FILE}.lock"
+    lock = FileLock(cache_file_lock, timeout=10)
+    with lock:
+        if os.path.exists(DB_CACHE_FILE):
+            with open(DB_CACHE_FILE, "r") as f:
+                try:
+                    cache = json.loads(f.read())
+                except json.JSONDecodeError:
+                    LOG.debug('JSONDecodeError in the local cache, dumping the full cache file.')
+                    cache = {}
+        else:
             cache = {}

-    with open(DB_CACHE_FILE, "w") as f:
-        cache[db_name] = {
-            "cached_at": time.time(),
-            "db": data
-        }
-        f.write(json.dumps(cache))
-        LOG.debug('Safety updated the cache file for %s database.', db_name)
+        with open(DB_CACHE_FILE, "w") as f:
+            cache[db_name] = {
+                "cached_at": time.time(),
+                "db": data
+            }
+            f.write(json.dumps(cache))
+            LOG.debug('Safety updated the cache file for %s database.', db_name)


 def fetch_database_url(session, mirror, db_name, cached, telemetry=True,
                        ecosystem: Ecosystem = Ecosystem.PYTHON, from_cache=True):
-    headers = {'schema-version': JSON_SCHEMA_VERSION, 'ecosystem': ecosystem.value}
+    headers = {'schema-version': JSON_SCHEMA_VERSION, 'ecosystem': ecosystem.value}

     if cached and from_cache:
         cached_data = get_from_cache(db_name=db_name, cache_valid_seconds=cached)
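
write_to_cache now performs its read-modify-write cycle while holding the same {DB_CACHE_FILE}.lock file that get_from_cache uses, so two safety processes can no longer interleave a read and a write and silently drop each other's cache entries; it also tolerates the cache file not existing yet. A condensed sketch of that locked update pattern (the file name is illustrative, not safety's):

import json
import os
import time

from filelock import FileLock

CACHE_FILE = "/tmp/example-cache.json"  # illustrative; safety uses DB_CACHE_FILE


def update_cache(db_name, db_data):
    # A single lock file guards both readers and writers of CACHE_FILE.
    with FileLock(f"{CACHE_FILE}.lock", timeout=10):
        cache = {}
        if os.path.exists(CACHE_FILE):
            try:
                with open(CACHE_FILE) as f:
                    cache = json.load(f)
            except json.JSONDecodeError:
                cache = {}  # corrupted cache: start from scratch, as the commit does
        cache[db_name] = {"cached_at": time.time(), "db": db_data}
        with open(CACHE_FILE, "w") as f:
            json.dump(cache, f)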
@@ -122,13 +133,13 @@ def fetch_database_url(session, mirror, db_name, cached, telemetry=True,
             return cached_data
     url = mirror + db_name

-
+
     telemetry_data = {
-        'telemetry': json.dumps(build_telemetry_data(telemetry=telemetry),
+        'telemetry': json.dumps(build_telemetry_data(telemetry=telemetry),
                                 default=pydantic_encoder)}

     try:
-        r = session.get(url=url, timeout=REQUEST_TIMEOUT,
+        r = session.get(url=url, timeout=REQUEST_TIMEOUT,
                         headers=headers, params=telemetry_data)
     except requests.exceptions.ConnectionError:
         raise NetworkConnectionError()
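
The removed and re-added lines in this hunk appear to differ only in trailing whitespace, and the remaining hunks in this file follow the same pattern. The context lines do show the call site that consumes get_from_cache: its truthiness check treats the new None exactly like the old False, so a cache miss still falls through to the HTTP request. A self-contained sketch of that contract, using stand-in functions rather than safety's real ones:

from typing import Any, Dict, Optional


def get_cached(db_name: str) -> Optional[Dict[str, Any]]:
    """Stand-in for get_from_cache: None signals 'not cached or expired'."""
    return None  # simulate a cache miss


def fetch_remote(db_name: str) -> Dict[str, Any]:
    """Stand-in for the network fetch performed by fetch_database_url."""
    return {"meta": {"base_domain": "https://data.safetycli.com"}}


def load_db(db_name: str) -> Dict[str, Any]:
    cached = get_cached(db_name)
    if cached:  # None and the old False are both falsy, so callers are unaffected
        return cached
    return fetch_remote(db_name)


print(load_db("insecure.json"))  # falls back to the stand-in network fetch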
@@ -205,10 +216,10 @@ def fetch_database_file(path: str, db_name: str, cached = 0,

     if not full_path.exists():
         raise DatabaseFileNotFoundError(db=path)
-
+
     with open(full_path) as f:
         data = json.loads(f.read())
-
+
     if cached:
         LOG.info('Writing %s to cache because cached value was %s', db_name, cached)
         write_to_cache(db_name, data)
@@ -226,7 +237,7 @@ def is_valid_database(db) -> bool:
         return False


-def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
+def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
                    ecosystem: Optional[Ecosystem] = None, from_cache=True):

     if session.is_using_auth_credentials():
@@ -242,7 +253,7 @@ def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
         if is_a_remote_mirror(mirror):
             if ecosystem is None:
                 ecosystem = Ecosystem.PYTHON
-            data = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
+            data = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
                                       telemetry=telemetry, ecosystem=ecosystem, from_cache=from_cache)
         else:
             data = fetch_database_file(mirror, db_name=db_name, cached=cached,
@@ -562,16 +573,16 @@ def compute_sec_ver(remediations, packages: Dict[str, Package], secure_vulns_by_
             secure_v = compute_sec_ver_for_user(package=pkg, secure_vulns_by_user=secure_vulns_by_user, db_full=db_full)

         rem['closest_secure_version'] = get_closest_ver(secure_v, version, spec)
-
+
         upgrade = rem['closest_secure_version'].get('upper', None)
         downgrade = rem['closest_secure_version'].get('lower', None)
         recommended_version = None
-
+
         if upgrade:
             recommended_version = upgrade
         elif downgrade:
             recommended_version = downgrade
-
+
         rem['recommended_version'] = recommended_version
         rem['other_recommended_versions'] = [other_v for other_v in secure_v if
                                              other_v != str(recommended_version)]
@@ -645,12 +656,12 @@ def process_fixes(files, remediations, auto_remediation_limit, output, no_output

 def process_fixes_scan(file_to_fix, to_fix_spec, auto_remediation_limit, output, no_output=True, prompt=False):
     to_fix_remediations = []
-
+
     def get_remmediation_from(spec):
         upper = None
         lower = None
         recommended = None
-
+
         try:
             upper = Version(spec.remediation.closest_secure.upper) if spec.remediation.closest_secure.upper else None
         except Exception as e:
@@ -664,15 +675,15 @@ def get_remmediation_from(spec):
         try:
             recommended = Version(spec.remediation.recommended)
         except Exception as e:
-            LOG.error(f'Error getting recommended version for remediation, ignoring', exc_info=True)
+            LOG.error(f'Error getting recommended version for remediation, ignoring', exc_info=True)

         return {
             "vulnerabilities_found": spec.remediation.vulnerabilities_found,
             "version": next(iter(spec.specifier)).version if spec.is_pinned() else None,
             "requirement": spec,
             "more_info_url": spec.remediation.more_info_url,
             "closest_secure_version": {
-                'upper': upper,
+                'upper': upper,
                 'lower': lower
             },
             "recommended_version": recommended,
@@ -690,7 +701,7 @@ def get_remmediation_from(spec):
         'files': {str(file_to_fix.location): {'content': None, 'fixes': {'TO_SKIP': [], 'TO_APPLY': [], 'TO_CONFIRM': []}, 'supported': False, 'filename': file_to_fix.location.name}},
         'dependencies': defaultdict(dict),
     }
-
+
     fixes = apply_fixes(requirements, output, no_output, prompt, scan_flow=True, auto_remediation_limit=auto_remediation_limit)

     return fixes
@@ -822,7 +833,7 @@ def apply_fixes(requirements, out_type, no_output, prompt, scan_flow=False, auto
     for name, data in requirements['files'].items():
         output = [('', {}),
                   (f"Analyzing {name}... [{get_fix_opt_used_msg(auto_remediation_limit)} limit]", {'styling': {'bold': True}, 'start_line_decorator': '->', 'indent': ' '})]
-
+
         r_skip = data['fixes']['TO_SKIP']
         r_apply = data['fixes']['TO_APPLY']
         r_confirm = data['fixes']['TO_CONFIRM']
@@ -901,7 +912,7 @@ def apply_fixes(requirements, out_type, no_output, prompt, scan_flow=False, auto
         else:
             not_supported_filename = data.get('filename', name)
             output.append(
-                (f"{not_supported_filename} updates not supported: Please update these dependencies using your package manager.",
+                (f"{not_supported_filename} updates not supported: Please update these dependencies using your package manager.",
                  {'start_line_decorator': ' -', 'indent': ' '}))
             output.append(('', {}))

@@ -999,7 +1010,7 @@ def review(*, report=None, params=None):

 @sync_safety_context
 def get_licenses(*, session=None, db_mirror=False, cached=0, telemetry=True):
-
+
     if db_mirror:
         mirrors = [db_mirror]
     else:
