diff --git a/Lib/http/server.py b/Lib/http/server.py index 64f766f9bc2c1b..9da5e31463a8bb 100644 --- a/Lib/http/server.py +++ b/Lib/http/server.py @@ -1336,7 +1336,7 @@ def test(HandlerClass=BaseHTTPRequestHandler, print("\nKeyboard interrupt received, exiting.") sys.exit(0) -if __name__ == '__main__': +def _main(args=None): import argparse import contextlib @@ -1362,7 +1362,7 @@ def test(HandlerClass=BaseHTTPRequestHandler, parser.add_argument('port', default=8000, type=int, nargs='?', help='bind to this port ' '(default: %(default)s)') - args = parser.parse_args() + args = parser.parse_args(args) if not args.tls_cert and args.tls_key: parser.error("--tls-key requires --tls-cert to be set") @@ -1407,3 +1407,7 @@ def finish_request(self, request, client_address): tls_key=args.tls_key, tls_password=tls_key_password, ) + + +if __name__ == '__main__': + _main() diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py index 2cafa4e45a1313..1f1a571f097ba9 100644 --- a/Lib/test/test_httpservers.py +++ b/Lib/test/test_httpservers.py @@ -8,8 +8,10 @@ SimpleHTTPRequestHandler, CGIHTTPRequestHandler from http import server, HTTPStatus +import contextlib import os import socket +import subprocess import sys import re import base64 @@ -21,6 +23,7 @@ import html import http, http.client import urllib.parse +import urllib.request import tempfile import time import datetime @@ -33,6 +36,8 @@ from test.support import ( is_apple, import_helper, os_helper, requires_subprocess, threading_helper ) +from test.support.script_helper import spawn_python, kill_python +from test.support.socket_helper import find_unused_port try: import ssl @@ -1536,6 +1541,274 @@ def test_server_test_ipv4(self, _): self.assertEqual(mock_server.address_family, socket.AF_INET) +class CommandLineTestCase(unittest.TestCase): + default_port = 8000 + default_bind = None + default_protocol = 'HTTP/1.0' + default_handler = SimpleHTTPRequestHandler + default_server = unittest.mock.ANY + tls_cert = certdata_file('ssl_cert.pem') + tls_key = certdata_file('ssl_key.pem') + tls_password = 'somepass' + args = { + 'HandlerClass': default_handler, + 'ServerClass': default_server, + 'protocol': default_protocol, + 'port': default_port, + 'bind': default_bind, + 'tls_cert': None, + 'tls_key': None, + 'tls_password': None, + } + + def setUp(self): + super().setUp() + self.tls_password_file = tempfile.mktemp() + with open(self.tls_password_file, 'wb') as f: + f.write(self.tls_password.encode()) + self.addCleanup(os_helper.unlink, self.tls_password_file) + + def invoke_httpd(self, *args): + output = StringIO() + with contextlib.redirect_stdout(output), \ + contextlib.redirect_stderr(output): + server._main(args) + return output.getvalue() + + @mock.patch('http.server.test') + def test_port_flag(self, mock_func): + ports = [8000, 65535] + for port in ports: + with self.subTest(port=port): + self.invoke_httpd(str(port)) + self.args['port'] = port + mock_func.assert_called_once_with(**self.args) + self.args['port'] = self.default_port + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_directory_flag(self, mock_func): + options = ['-d', '--directory'] + directories = ['.', '/foo', '\\bar', '/', + 'C:\\', 'C:\\foo', 'C:\\bar', + '/home/user', './foo/foo2', 'D:\\foo\\bar'] + for flag in options: + for directory in directories: + with self.subTest(flag=flag, directory=directory): + self.invoke_httpd(flag, directory) + mock_func.assert_called_once_with(**self.args) + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_bind_flag(self, mock_func): + options = ['-b', '--bind'] + bind_addresses = ['localhost', '127.0.0.1', '::1', + '0.0.0.0', '8.8.8.8',] + for flag in options: + for bind_address in bind_addresses: + with self.subTest(flag=flag, bind_address=bind_address): + self.invoke_httpd(flag, bind_address) + self.args['bind'] = bind_address + mock_func.assert_called_once_with(**self.args) + self.args['bind'] = self.default_bind + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_protocol_flag(self, mock_func): + options = ['-p', '--protocol'] + protocols = ['HTTP/1.0', 'HTTP/1.1', 'HTTP/2.0', 'HTTP/3.0'] + for flag in options: + for protocol in protocols: + with self.subTest(flag=flag, protocol=protocol): + self.invoke_httpd(flag, protocol) + self.args['protocol'] = protocol + mock_func.assert_called_once_with(**self.args) + self.args['protocol'] = self.default_protocol + mock_func.reset_mock() + + # TODO: This test should be removed once the CGI component is removed(3.15) + @mock.patch('http.server.test') + def test_cgi_flag(self, mock_func): + self.invoke_httpd('--cgi') + mock_func.assert_called_once_with(HandlerClass=CGIHTTPRequestHandler, + ServerClass=self.default_server, + protocol=self.default_protocol, + port=self.default_port, + bind=self.default_bind, + tls_cert=None, + tls_key=None, + tls_password=None) + + @unittest.skipIf(ssl is None, "requires ssl") + @mock.patch('http.server.test') + def test_tls_flag(self, mock_func): + tls_cert_options = ['--tls-cert', ] + tls_key_options = ['--tls-key', ] + tls_password_options = ['--tls-password-file', ] + # Normal: --tls-cert and --tls-key + + for tls_cert_option in tls_cert_options: + for tls_key_option in tls_key_options: + self.invoke_httpd(tls_cert_option, self.tls_cert, + tls_key_option, self.tls_key) + self.args['tls_cert'] = self.tls_cert + self.args['tls_key'] = self.tls_key + mock_func.assert_called_once_with(**self.args) + self.args['tls_cert'] = None + self.args['tls_key'] = None + mock_func.reset_mock() + + # Normal: --tls-cert, --tls-key and --tls-password-file + + for tls_cert_option in tls_cert_options: + for tls_key_option in tls_key_options: + for tls_password_option in tls_password_options: + self.invoke_httpd(tls_cert_option, + self.tls_cert, + tls_key_option, + self.tls_key, + tls_password_option, + self.tls_password_file) + self.args['tls_cert'] = self.tls_cert + self.args['tls_key'] = self.tls_key + self.args['tls_password'] = self.tls_password + mock_func.assert_called_once_with(**self.args) + self.args['tls_cert'] = None + self.args['tls_key'] = None + self.args['tls_password'] = None + mock_func.reset_mock() + + # Abnormal: --tls-key without --tls-cert + + for tls_key_option in tls_key_options: + for tls_cert_option in tls_cert_options: + with self.assertRaises(SystemExit): + self.invoke_httpd(tls_key_option, self.tls_key) + mock_func.reset_mock() + + # Abnormal: --tls-password-file without --tls-cert + + for tls_password_option in tls_password_options: + with self.assertRaises(SystemExit): + self.invoke_httpd(tls_password_option, self.tls_password_file) + mock_func.reset_mock() + + # Abnormal: --tls-password-file cannot be opened + + non_existent_file = 'non_existent_file' + for tls_password_option in tls_password_options: + for tls_cert_option in tls_cert_options: + with self.assertRaises(SystemExit): + self.invoke_httpd(tls_cert_option, + self.tls_cert, + tls_password_option, + non_existent_file) + + @mock.patch('http.server.test') + def test_no_arguments(self, mock_func): + self.invoke_httpd() + mock_func.assert_called_once_with(HandlerClass=self.default_handler, + ServerClass=self.default_server, + protocol=self.default_protocol, + port=self.default_port, + bind=self.default_bind, + tls_cert=None, + tls_key=None, + tls_password=None) + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_help_flag(self, _): + options = ['-h', '--help'] + for option in options: + with self.assertRaises(SystemExit): + output = self.invoke_httpd(option) + self.assertStartsWith(output, 'usage: ') + + @mock.patch('http.server.test') + def test_unknown_flag(self, _): + with self.assertRaises(SystemExit): + output = self.invoke_httpd('--unknown-flag') + self.assertStartsWith(output, 'usage: ') + + +class CommandLineRunTimeTestCase(unittest.TestCase): + random_data = os.urandom(1024) + random_file_name = 'random.bin' + tls_cert = certdata_file('ssl_cert.pem') + tls_key = certdata_file('ssl_key.pem') + tls_password = 'somepass' + + def setUp(self): + super().setUp() + with open(self.random_file_name, 'wb') as f: + f.write(self.random_data) + self.addCleanup(os_helper.unlink, self.random_file_name) + self.tls_password_file = tempfile.mktemp() + with open(self.tls_password_file, 'wb') as f: + f.write(self.tls_password.encode()) + self.addCleanup(os_helper.unlink, self.tls_password_file) + + def fetch_file(self, path, allow_self_signed_cert=True) -> bytes: + context = ssl.create_default_context() + if allow_self_signed_cert: + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + req = urllib.request.Request(path, method='GET') + res = urllib.request.urlopen(req, context=context) + return res.read() + + def parse_cli_output(self, output: str) -> tuple[str, str, int]: + try: + matches = re.search(r'\((https?)://([^/:]+):(\d+)/?\)', output) + return matches.group(1), matches.group(2), int(matches.group(3)) + except: + return None, None, None + + def wait_for_server(self, proc, protocol, port, bind, timeout=50) -> bool: + while timeout > 0: + line = proc.stdout.readline() + if not line: + time.sleep(0.1) + timeout -= 1 + continue + _protocol, _host, _port = self.parse_cli_output(line) + if not _protocol or not _host or not _port: + time.sleep(0.1) + timeout -= 1 + continue + if _protocol == protocol and _host == bind and _port == port: + return True + else: + break + return False + + def test_http_client(self): + port = find_unused_port() + bind = '127.0.0.1' + proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind, + bufsize=1, text=True) + self.assertTrue(self.wait_for_server(proc, 'http', port, bind)) + res = self.fetch_file(f'http://{bind}:{port}/{self.random_file_name}') + self.assertEqual(res, self.random_data) + proc.terminate() + kill_python(proc) + + def test_https_client(self): + port = find_unused_port() + bind = '127.0.0.1' + proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind, + '--tls-cert', self.tls_cert, + '--tls-key', self.tls_key, + '--tls-password-file', self.tls_password_file, + bufsize=1, text=True) + self.assertTrue(self.wait_for_server(proc, 'https', port, bind)) + res = self.fetch_file(f'https://{bind}:{port}/{self.random_file_name}') + self.assertEqual(res, self.random_data) + proc.terminate() + kill_python(proc) + + def setUpModule(): unittest.addModuleCleanup(os.chdir, os.getcwd())