Skip to content

Commit d741a15

Browse files
committed
Return only the file descriptor when creating temp files
Signed-off-by: Ivan Kanakarakis <[email protected]>
1 parent 2ce87f0 commit d741a15

File tree

2 files changed

+50
-68
lines changed

2 files changed

+50
-68
lines changed

src/saml2/entity.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ def __init__(self, entity_type, config=None, config_file="",
144144
if _val.startswith("http"):
145145
r = requests.request("GET", _val)
146146
if r.status_code == 200:
147-
_, filename = make_temp(r.text, ".pem", False)
148-
setattr(self.config, item, filename)
147+
tmp = make_temp(r.text, ".pem", False)
148+
setattr(self.config, item, tmp.name)
149149
else:
150150
raise Exception(
151151
"Could not fetch certificate from %s" % _val)
@@ -568,8 +568,8 @@ def _encrypt_assertion(self, encrypt_cert, sp_entity_id, response,
568568
_cert = "%s%s" % (begin_cert, _cert)
569569
if end_cert not in _cert:
570570
_cert = "%s%s" % (_cert, end_cert)
571-
_, cert_file = make_temp(_cert.encode('ascii'), decode=False)
572-
response = self.sec.encrypt_assertion(response, cert_file,
571+
tmp = make_temp(_cert.encode('ascii'), decode=False)
572+
response = self.sec.encrypt_assertion(response, tmp.name,
573573
pre_encryption_part(),
574574
node_xpath=node_xpath)
575575
return response

src/saml2/sigver.py

Lines changed: 46 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -336,31 +336,31 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
336336
return instance
337337

338338

339-
def make_temp(string, suffix='', decode=True, delete=True):
340-
""" xmlsec needs files in some cases where only strings exist, hence the
341-
need for this function. It creates a temporary file with the
342-
string as only content.
339+
def make_temp(content, suffix="", decode=True, delete=True):
340+
"""
341+
Create a temporary file with the given content.
342+
343+
This is needed by xmlsec in some cases where only strings exist when files
344+
are expected.
343345
344-
:param string: The information to be placed in the file
346+
:param content: The information to be placed in the file
345347
:param suffix: The temporary file might have to have a specific
346348
suffix in certain circumstances.
347-
:param decode: The input string might be base64 coded. If so it
349+
:param decode: The input content might be base64 coded. If so it
348350
must, in some cases, be decoded before being placed in the file.
349351
:return: 2-tuple with file pointer ( so the calling function can
350352
close the file) and filename (which is for instance needed by the
351353
xmlsec function).
352354
"""
353-
ntf = NamedTemporaryFile(suffix=suffix, delete=delete)
354-
# Python3 tempfile requires byte-like object
355-
if not isinstance(string, six.binary_type):
356-
string = string.encode('utf-8')
357-
358-
if decode:
359-
ntf.write(base64.b64decode(string))
360-
else:
361-
ntf.write(string)
355+
content_encoded = (
356+
content.encode("utf-8") if not isinstance(content, six.binary_type) else content
357+
)
358+
content_raw = base64.b64decode(content_encoded) if decode else content_encoded
359+
delete_tmpfiles = delete
360+
ntf = NamedTemporaryFile(suffix=suffix, delete=delete_tmpfiles)
361+
ntf.write(content_raw)
362362
ntf.seek(0)
363-
return ntf, ntf.name
363+
return ntf
364364

365365

366366
def split_len(seq, length):
@@ -722,13 +722,13 @@ def encrypt(self, text, recv_key, template, session_key_type, xpath=''):
722722
:return:
723723
"""
724724
logger.debug('Encryption input len: %d', len(text))
725-
f, fil = make_temp(text, decode=False)
725+
tmp = make_temp(text, decode=False)
726726
com_list = [
727727
self.xmlsec,
728728
'--encrypt',
729729
'--pubkey-cert-pem', recv_key,
730730
'--session-key', session_key_type,
731-
'--xml-data', fil,
731+
'--xml-data', tmp.name,
732732
]
733733

734734
if xpath:
@@ -759,9 +759,8 @@ def encrypt_assertion(self, statement, enc_key, template, key_type='des-192', no
759759
if isinstance(statement, SamlBase):
760760
statement = pre_encrypt_assertion(statement)
761761

762-
f, fil = make_temp(
763-
_str(statement), decode=False)
764-
t, tmpl = make_temp(_str(template), decode=False)
762+
tmp = make_temp(_str(statement), decode=False)
763+
tmp2 = make_temp(_str(template), decode=False)
765764

766765
if not node_xpath:
767766
node_xpath = ASSERT_XPATH
@@ -771,15 +770,15 @@ def encrypt_assertion(self, statement, enc_key, template, key_type='des-192', no
771770
'--encrypt',
772771
'--pubkey-cert-pem', enc_key,
773772
'--session-key', key_type,
774-
'--xml-data', fil,
773+
'--xml-data', tmp.name,
775774
'--node-xpath', node_xpath,
776775
]
777776

778777
if node_id:
779778
com_list.extend(['--node-id', node_id])
780779

781780
try:
782-
(_stdout, _stderr, output) = self._run_xmlsec(com_list, [tmpl])
781+
(_stdout, _stderr, output) = self._run_xmlsec(com_list, [tmp2.name])
783782
except XmlsecError as e:
784783
six.raise_from(EncryptError(com_list), e)
785784

@@ -794,7 +793,7 @@ def decrypt(self, enctext, key_file, id_attr):
794793
"""
795794

796795
logger.debug('Decrypt input len: %d', len(enctext))
797-
_, fil = make_temp(enctext, decode=False)
796+
tmp = make_temp(enctext, decode=False)
798797

799798
com_list = [
800799
self.xmlsec,
@@ -805,7 +804,7 @@ def decrypt(self, enctext, key_file, id_attr):
805804
]
806805

807806
try:
808-
(_stdout, _stderr, output) = self._run_xmlsec(com_list, [fil])
807+
(_stdout, _stderr, output) = self._run_xmlsec(com_list, [tmp.name])
809808
except XmlsecError as e:
810809
six.raise_from(DecryptError(com_list), e)
811810

@@ -826,12 +825,7 @@ def sign_statement(self, statement, node_name, key_file, node_id, id_attr):
826825
if isinstance(statement, SamlBase):
827826
statement = str(statement)
828827

829-
_, fil = make_temp(
830-
statement,
831-
suffix='.xml',
832-
decode=False,
833-
delete=self._xmlsec_delete_tmpfiles,
834-
)
828+
tmp = make_temp(statement, suffix=".xml", decode=False, delete=self._xmlsec_delete_tmpfiles)
835829

836830
com_list = [
837831
self.xmlsec,
@@ -845,7 +839,7 @@ def sign_statement(self, statement, node_name, key_file, node_id, id_attr):
845839
com_list.extend(['--node-id', node_id])
846840

847841
try:
848-
(stdout, stderr, output) = self._run_xmlsec(com_list, [fil])
842+
(stdout, stderr, output) = self._run_xmlsec(com_list, [tmp.name])
849843
except XmlsecError as e:
850844
raise SignatureError(com_list)
851845

@@ -872,12 +866,7 @@ def validate_signature(self, signedtext, cert_file, cert_type, node_name, node_i
872866
if not isinstance(signedtext, six.binary_type):
873867
signedtext = signedtext.encode('utf-8')
874868

875-
_, fil = make_temp(
876-
signedtext,
877-
suffix='.xml',
878-
decode=False,
879-
delete=self._xmlsec_delete_tmpfiles,
880-
)
869+
tmp = make_temp(signedtext, suffix=".xml", decode=False, delete=self._xmlsec_delete_tmpfiles)
881870

882871
com_list = [
883872
self.xmlsec,
@@ -892,7 +881,7 @@ def validate_signature(self, signedtext, cert_file, cert_type, node_name, node_i
892881
com_list.extend(['--node-id', node_id])
893882

894883
try:
895-
(_stdout, stderr, _output) = self._run_xmlsec(com_list, [fil])
884+
(_stdout, stderr, _output) = self._run_xmlsec(com_list, [tmp.name])
896885
except XmlsecError as e:
897886
six.raise_from(SignatureError(com_list), e)
898887

@@ -1369,15 +1358,16 @@ def decrypt_keys(self, enctext, keys=None, id_attr=''):
13691358
if not isinstance(keys, list):
13701359
keys = [keys]
13711360

1372-
keys = [key for key in keys if key]
1373-
for key in keys:
1374-
if not isinstance(key, six.binary_type):
1375-
key = key.encode("ascii")
1376-
key_file, _ = make_temp(key, decode=False)
1377-
key_files.append(key_file)
1361+
keys_filtered = (key for key in keys if key)
1362+
keys_encoded = (
1363+
key.encode("ascii") if not isinstance(key, six.binary_type) else key
1364+
for key in keys_filtered
1365+
)
1366+
key_files = list(make_temp(key, decode=False) for key in keys_encoded)
1367+
key_file_names = list(tmp.name for tmp in key_files)
13781368

13791369
try:
1380-
dectext = self.decrypt(enctext, key_file=[x.name for x in key_files], id_attr=id_attr)
1370+
dectext = self.decrypt(enctext, key_file=key_file_names, id_attr=id_attr)
13811371
except DecryptError as e:
13821372
raise
13831373
else:
@@ -1462,14 +1452,9 @@ def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, origdoc=None,
14621452

14631453
for cert in _certs:
14641454
if isinstance(cert, six.string_types):
1465-
certs.append(
1466-
make_temp(
1467-
pem_format(cert),
1468-
suffix='.pem',
1469-
decode=False,
1470-
delete=self._xmlsec_delete_tmpfiles,
1471-
)
1472-
)
1455+
content = pem_format(cert)
1456+
tmp = make_temp(content, suffix=".pem", decode=False, delete=self._xmlsec_delete_tmpfiles)
1457+
certs.append(tmp)
14731458
else:
14741459
certs.append(cert)
14751460
else:
@@ -1478,12 +1463,7 @@ def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, origdoc=None,
14781463
if not certs and not self.only_use_keys_in_metadata:
14791464
logger.debug('==== Certs from instance ====')
14801465
certs = [
1481-
make_temp(
1482-
pem_format(cert),
1483-
suffix='.pem',
1484-
decode=False,
1485-
delete=self._xmlsec_delete_tmpfiles,
1486-
)
1466+
make_temp(content=pem_format(cert), suffix=".pem", decode=False, delete=self._xmlsec_delete_tmpfiles)
14871467
for cert in cert_from_instance(item)
14881468
]
14891469
else:
@@ -1495,12 +1475,12 @@ def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, origdoc=None,
14951475
verified = False
14961476
last_pem_file = None
14971477

1498-
for _, pem_file in certs:
1478+
for pem_fd in certs:
14991479
try:
1500-
last_pem_file = pem_file
1480+
last_pem_file = pem_fd.name
15011481
if self.verify_signature(
15021482
decoded_xml,
1503-
pem_file,
1483+
pem_fd.name,
15041484
node_name=node_name,
15051485
node_id=item.id,
15061486
id_attr=id_attr):
@@ -1670,7 +1650,9 @@ def sign_statement(self, statement, node_name, key=None, key_file=None, node_id=
16701650
id_attr = self.id_attr
16711651

16721652
if not key_file and key:
1673-
_, key_file = make_temp(str(key).encode(), '.pem')
1653+
content = str(key).encode()
1654+
tmp = make_temp(content, suffix=".pem")
1655+
key_file = tmp.name
16741656

16751657
if not key and not key_file:
16761658
key_file = self.key_file

0 commit comments

Comments
 (0)