Skip to content

Commit 1fd509e

Browse files
committed
refactor: use Path objects instead of strings
Don't convert Paths to Paths. Simplify code and take advantage of Path object methods. Signed-off-by: Rafal Ilnicki <[email protected]>
1 parent 20e33f6 commit 1fd509e

16 files changed

+55
-73
lines changed

cve_bin_tool/cli.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -830,8 +830,7 @@ def main(argv=None):
830830
error_mode=error_mode,
831831
)
832832

833-
# if OLD_CACHE_DIR (from cvedb.py) exists, print warning
834-
if Path(OLD_CACHE_DIR).exists():
833+
if OLD_CACHE_DIR.exists():
835834
LOGGER.warning(
836835
f"Obsolete cache dir {OLD_CACHE_DIR} is no longer needed and can be removed."
837836
)

cve_bin_tool/cve_scanner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import sys
66
from collections import defaultdict
77
from logging import Logger
8-
from pathlib import Path
98
from string import ascii_lowercase
109
from typing import DefaultDict, Dict, List
1110

@@ -31,7 +30,7 @@ class CVEScanner:
3130
all_cve_version_info: Dict[str, VersionInfo]
3231

3332
RANGE_UNSET: str = ""
34-
dbname: str = str(Path(DISK_LOCATION_DEFAULT) / DBNAME)
33+
dbname: str = str(DISK_LOCATION_DEFAULT / DBNAME)
3534
CONSOLE: Console = Console(file=sys.stderr, theme=cve_theme)
3635
ALPHA_TO_NUM: Dict[str, int] = dict(zip(ascii_lowercase, range(26)))
3736

cve_bin_tool/data_sources/curl_source.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import json
77
import logging
8-
from pathlib import Path
98

109
import aiohttp
1110

@@ -66,7 +65,7 @@ async def download_curl_vulnerabilities(self, session: RateLimiter) -> None:
6665
async with await session.get(self.DATA_SOURCE_LINK) as response:
6766
response.raise_for_status()
6867
self.vulnerability_data = await response.json()
69-
path = Path(str(Path(self.cachedir) / "vuln.json"))
68+
path = self.cachedir / "vuln.json"
7069
filepath = path.resolve()
7170
async with FileIO(filepath, "w") as f:
7271
await f.write(json.dumps(self.vulnerability_data, indent=4))

cve_bin_tool/data_sources/epss_source.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import csv
77
import gzip
88
import logging
9-
import os
109
from datetime import datetime, timedelta
1110
from io import StringIO
1211
from pathlib import Path
@@ -34,8 +33,8 @@ def __init__(self, error_mode=ErrorMode.TruncTrace):
3433
self.error_mode = error_mode
3534
self.cachedir = self.CACHEDIR
3635
self.backup_cachedir = self.BACKUPCACHEDIR
37-
self.epss_path = str(Path(self.cachedir) / "epss")
38-
self.file_name = os.path.join(self.epss_path, "epss_scores-current.csv")
36+
self.epss_path = self.cachedir / "epss"
37+
self.file_name = self.epss_path / "epss_scores-current.csv"
3938
self.source_name = self.SOURCE
4039

4140
async def update_epss(self):
@@ -58,11 +57,11 @@ async def download_epss_data(self):
5857
"""Downloads the EPSS CSV file and saves it to the local filesystem.
5958
The download is only performed if the file is older than 24 hours.
6059
"""
61-
os.makedirs(self.epss_path, exist_ok=True)
60+
self.epss_path.mkdir(parents=True, exist_ok=True)
6261
# Check if the file exists
63-
if os.path.exists(self.file_name):
62+
if self.file_name.exists():
6463
# Get the modification time of the file
65-
modified_time = os.path.getmtime(self.file_name)
64+
modified_time = self.file_name.stat().st_mtime
6665
last_modified = datetime.fromtimestamp(modified_time)
6766

6867
# Calculate the time difference between now and the last modified time
@@ -80,8 +79,7 @@ async def download_epss_data(self):
8079
decompressed_data = gzip.decompress(await response.read())
8180

8281
# Save the downloaded data to the file
83-
with open(self.file_name, "wb") as file:
84-
file.write(decompressed_data)
82+
self.file_name.write_bytes(decompressed_data)
8583

8684
except aiohttp.ClientError as e:
8785
self.LOGGER.error(f"An error occurred during updating epss {e}")
@@ -102,8 +100,7 @@ async def download_epss_data(self):
102100
decompressed_data = gzip.decompress(await response.read())
103101

104102
# Save the downloaded data to the file
105-
with open(self.file_name, "wb") as file:
106-
file.write(decompressed_data)
103+
self.file_name.write_bytes(decompressed_data)
107104

108105
except aiohttp.ClientError as e:
109106
self.LOGGER.error(f"An error occurred during downloading epss {e}")
@@ -114,9 +111,8 @@ def parse_epss_data(self, file_path=None):
114111
if file_path is None:
115112
file_path = self.file_name
116113

117-
with open(file_path) as file:
118-
# Read the content of the CSV file
119-
decoded_data = file.read()
114+
# Read the content of the CSV file
115+
decoded_data = Path(file_path).read_text()
120116

121117
# Create a CSV reader to read the data from the decoded CSV content
122118
reader = csv.reader(StringIO(decoded_data), delimiter=",")

cve_bin_tool/data_sources/gad_source.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import io
99
import re
1010
import zipfile
11-
from pathlib import Path
1211

1312
import aiohttp
1413
import yaml
@@ -39,7 +38,7 @@ def __init__(
3938
):
4039
self.cachedir = self.CACHEDIR
4140
self.slugs = None
42-
self.gad_path = str(Path(self.cachedir) / "gad")
41+
self.gad_path = self.cachedir / "gad"
4342
self.source_name = self.SOURCE
4443

4544
self.error_mode = error_mode
@@ -90,8 +89,8 @@ async def fetch_cves(self):
9089

9190
self.db = cvedb.CVEDB()
9291

93-
if not Path(self.gad_path).exists():
94-
Path(self.gad_path).mkdir()
92+
if not self.gad_path.exists():
93+
self.gad_path.mkdir()
9594
# As no data, force full update
9695
self.incremental_update = False
9796

@@ -155,7 +154,7 @@ async def fetch_cves(self):
155154
async def update_cve_entries(self):
156155
"""Updates CVE entries from CVEs in cache."""
157156

158-
p = Path(self.gad_path).glob("**/*")
157+
p = self.gad_path.glob("**/*")
159158
# Need to find files which are new to the cache
160159
last_update_timestamp = (
161160
self.time_of_last_update.timestamp()

cve_bin_tool/data_sources/nvd_source.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import logging
1313
import re
1414
import sqlite3
15-
from pathlib import Path
1615

1716
import aiohttp
1817
from rich.progress import track
@@ -26,7 +25,6 @@
2625
NVD_FILENAME_TEMPLATE,
2726
)
2827
from cve_bin_tool.error_handler import (
29-
AttemptedToWriteOutsideCachedir,
3028
CVEDataForYearNotInCache,
3129
ErrorHandler,
3230
ErrorMode,
@@ -77,7 +75,7 @@ def __init__(
7775
self.source_name = self.SOURCE
7876

7977
# set up the db if needed
80-
self.dbpath = str(Path(self.cachedir) / DBNAME)
78+
self.dbpath = self.cachedir / DBNAME
8179
self.connection: sqlite3.Connection | None = None
8280
self.session = session
8381
self.cve_count = -1
@@ -543,12 +541,9 @@ async def cache_update(
543541
Update the cache for a single year of NVD data.
544542
"""
545543
filename = url.split("/")[-1]
546-
# Ensure we only write to files within the cachedir
547-
cache_path = Path(self.cachedir)
548-
filepath = Path(str(cache_path / filename)).resolve()
549-
if not str(filepath).startswith(str(cache_path.resolve())):
550-
with ErrorHandler(mode=self.error_mode, logger=self.LOGGER):
551-
raise AttemptedToWriteOutsideCachedir(filepath)
544+
cache_path = self.cachedir
545+
filepath = cache_path / filename
546+
552547
# Validate the contents of the cached file
553548
if filepath.is_file():
554549
# Validate the sha and write out
@@ -600,7 +595,7 @@ def load_nvd_year(self, year: int) -> dict[str, str | object]:
600595
Return the dict of CVE data for the given year.
601596
"""
602597

603-
filename = Path(self.cachedir) / self.NVDCVE_FILENAME_TEMPLATE.format(year)
598+
filename = self.cachedir / self.NVDCVE_FILENAME_TEMPLATE.format(year)
604599
# Check if file exists
605600
if not filename.is_file():
606601
with ErrorHandler(mode=self.error_mode, logger=self.LOGGER):
@@ -619,5 +614,5 @@ def nvd_years(self) -> list[int]:
619614
"""
620615
return sorted(
621616
int(filename.split(".")[-3].split("-")[-1])
622-
for filename in glob.glob(str(Path(self.cachedir) / "nvdcve-1.1-*.json.gz"))
617+
for filename in glob.glob(str(self.cachedir / "nvdcve-1.1-*.json.gz"))
623618
)

cve_bin_tool/data_sources/osv_source.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import datetime
88
import io
99
import json
10-
import os
1110
import shutil
1211
import zipfile
1312
from pathlib import Path
@@ -25,7 +24,7 @@
2524

2625
def find_gsutil():
2726
gsutil_path = shutil.which("gsutil")
28-
if not os.path.exists(gsutil_path):
27+
if not Path(gsutil_path).exists():
2928
raise FileNotFoundError(
3029
"gsutil not found. Did you need to install requirements or activate a venv where gsutil is installed?"
3130
)
@@ -46,7 +45,7 @@ def __init__(
4645
):
4746
self.cachedir = self.CACHEDIR
4847
self.ecosystems = None
49-
self.osv_path = str(Path(self.cachedir) / "osv")
48+
self.osv_path = self.cachedir / "osv"
5049
self.source_name = self.SOURCE
5150

5251
self.error_mode = error_mode
@@ -113,7 +112,7 @@ async def get_ecosystem_incremental(self, ecosystem, time_of_last_update, sessio
113112
tasks.append(task)
114113

115114
for r in await asyncio.gather(*tasks):
116-
filepath = Path(self.osv_path) / (r.get("id") + ".json")
115+
filepath = self.osv_path / (r.get("id") + ".json")
117116
r = json.dumps(r)
118117

119118
async with FileIO(filepath, "w") as f:
@@ -158,9 +157,9 @@ async def get_totalfiles(self, ecosystem):
158157

159158
gsutil_path = find_gsutil() # use helper function
160159
gs_file = self.gs_url + ecosystem + "/all.zip"
161-
await aio_run_command([gsutil_path, "cp", gs_file, self.osv_path])
160+
await aio_run_command([gsutil_path, "cp", gs_file, str(self.osv_path)])
162161

163-
zip_path = Path(self.osv_path) / "all.zip"
162+
zip_path = self.osv_path / "all.zip"
164163
totalfiles = 0
165164

166165
with zipfile.ZipFile(zip_path, "r") as z:
@@ -179,8 +178,8 @@ async def fetch_cves(self):
179178

180179
self.db = cvedb.CVEDB()
181180

182-
if not Path(self.osv_path).exists():
183-
Path(self.osv_path).mkdir()
181+
if not self.osv_path.exists():
182+
self.osv_path.mkdir()
184183
# As no data, force full update
185184
self.incremental_update = False
186185

@@ -239,7 +238,7 @@ async def fetch_cves(self):
239238
async def update_cve_entries(self):
240239
"""Updates CVE entries from CVEs in cache"""
241240

242-
p = Path(self.osv_path).glob("**/*")
241+
p = self.osv_path.glob("**/*")
243242
# Need to find files which are new to the cache
244243
last_update_timestamp = (
245244
self.time_of_last_update.timestamp()

cve_bin_tool/data_sources/purl2cpe_source.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import zipfile
44
from io import BytesIO
5-
from pathlib import Path
65

76
import aiohttp
87

@@ -25,7 +24,7 @@ def __init__(
2524
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
2625
):
2726
self.cachedir = self.CACHEDIR
28-
self.purl2cpe_path = str(Path(self.cachedir) / "purl2cpe")
27+
self.purl2cpe_path = self.cachedir / "purl2cpe"
2928
self.source_name = self.SOURCE
3029
self.error_mode = error_mode
3130
self.incremental_update = incremental_update
@@ -36,8 +35,8 @@ async def fetch_cves(self):
3635
"""Fetches PURL2CPE database and places it in purl2cpe_path."""
3736
LOGGER.info("Getting PURL2CPE data...")
3837

39-
if not Path(self.purl2cpe_path).exists():
40-
Path(self.purl2cpe_path).mkdir()
38+
if not self.purl2cpe_path.exists():
39+
self.purl2cpe_path.mkdir()
4140

4241
if not self.session:
4342
connector = aiohttp.TCPConnector(limit_per_host=10)

cve_bin_tool/data_sources/redhat_source.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import datetime
55
import json
6-
from pathlib import Path
76

87
import aiohttp
98

@@ -28,7 +27,7 @@ def __init__(
2827
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
2928
):
3029
self.cachedir = self.CACHEDIR
31-
self.redhat_path = str(Path(self.cachedir) / "redhat")
30+
self.redhat_path = self.cachedir / "redhat"
3231
self.source_name = self.SOURCE
3332

3433
self.error_mode = error_mode
@@ -57,7 +56,7 @@ async def store_data(self, content):
5756
"""Asynchronously stores CVE data in separate JSON files, excluding entries without a CVE ID."""
5857
for c in content:
5958
if c["CVE"] != "":
60-
filepath = Path(self.redhat_path) / (str(c["CVE"]) + ".json")
59+
filepath = self.redhat_path / (str(c["CVE"]) + ".json")
6160
r = json.dumps(c)
6261
async with FileIO(filepath, "w") as f:
6362
await f.write(r)
@@ -73,8 +72,8 @@ async def fetch_cves(self):
7372

7473
self.db = cvedb.CVEDB()
7574

76-
if not Path(self.redhat_path).exists():
77-
Path(self.redhat_path).mkdir()
75+
if not self.redhat_path.exists():
76+
self.redhat_path.mkdir()
7877
# As no data, force full update
7978
self.incremental_update = False
8079

@@ -121,7 +120,7 @@ async def fetch_cves(self):
121120
async def update_cve_entries(self):
122121
"""Updates CVE entries from CVEs in cache."""
123122

124-
p = Path(self.redhat_path).glob("**/*")
123+
p = self.redhat_path.glob("**/*")
125124
# Need to find files which are new to the cache
126125
last_update_timestamp = (
127126
self.time_of_last_update.timestamp()

cve_bin_tool/data_sources/rsd_source.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import io
99
import json
1010
import zipfile
11-
from pathlib import Path
1211

1312
import aiohttp
1413
from cvss import CVSS2, CVSS3
@@ -36,7 +35,7 @@ def __init__(
3635
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
3736
):
3837
self.cachedir = self.CACHEDIR
39-
self.rsd_path = str(Path(self.cachedir) / "rsd")
38+
self.rsd_path = self.cachedir / "rsd"
4039
self.source_name = self.SOURCE
4140

4241
self.error_mode = error_mode
@@ -71,8 +70,8 @@ async def fetch_cves(self):
7170

7271
self.db = cvedb.CVEDB()
7372

74-
if not Path(self.rsd_path).exists():
75-
Path(self.rsd_path).mkdir()
73+
if not self.rsd_path.exists():
74+
self.rsd_path.mkdir()
7675

7776
if not self.session:
7877
connector = aiohttp.TCPConnector(limit_per_host=19)
@@ -133,7 +132,7 @@ async def fetch_cves(self):
133132
async def update_cve_entries(self):
134133
"""Updates CVE entries from CVEs in cache."""
135134

136-
p = Path(self.rsd_path).glob("**/*")
135+
p = self.rsd_path.glob("**/*")
137136
# Need to find files which are new to the cache
138137
last_update_timestamp = (
139138
self.time_of_last_update.timestamp()

cve_bin_tool/helper_script.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import textwrap
1010
from collections import ChainMap
1111
from logging import Logger
12-
from pathlib import Path
1312
from typing import MutableMapping
1413

1514
from rich import print as rprint
@@ -46,7 +45,7 @@ def __init__(
4645

4746
# for setting the database
4847
self.connection = None
49-
self.dbpath = str(Path(DISK_LOCATION_DEFAULT) / DBNAME)
48+
self.dbpath = DISK_LOCATION_DEFAULT / DBNAME
5049

5150
# for extraction
5251
self.walker = DirWalk().walk

0 commit comments

Comments
 (0)