diff --git a/adafruit_httpserver/exceptions.py b/adafruit_httpserver/exceptions.py new file mode 100644 index 0000000..ca70712 --- /dev/null +++ b/adafruit_httpserver/exceptions.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022 Dan Halbert for Adafruit Industries +# +# SPDX-License-Identifier: MIT +""" +`adafruit_httpserver.exceptions` +==================================================== +* Author(s): MichaƂ Pokusa +""" + + +class InvalidPathError(Exception): + """ + Parent class for all path related errors. + """ + + +class ParentDirectoryReferenceError(InvalidPathError): + """ + Path contains ``..``, a reference to the parent directory. + """ + + def __init__(self, path: str) -> None: + """Creates a new ``ParentDirectoryReferenceError`` for the ``path``.""" + super().__init__(f"Parent directory reference in path: {path}") + + +class BackslashInPathError(InvalidPathError): + """ + Backslash ``\\`` in path. + """ + + def __init__(self, path: str) -> None: + """Creates a new ``BackslashInPathError`` for the ``path``.""" + super().__init__(f"Backslash in path: {path}") + + +class ResponseAlreadySentError(Exception): + """ + Another ``HTTPResponse`` has already been sent. There can only be one per ``HTTPRequest``. + """ + + +class FileNotExistsError(Exception): + """ + Raised when a file does not exist. + """ + + def __init__(self, path: str) -> None: + """ + Creates a new ``FileNotExistsError`` for the file at ``path``. + """ + super().__init__(f"File does not exist: {path}") diff --git a/adafruit_httpserver/response.py b/adafruit_httpserver/response.py index 19f69de..0444039 100644 --- a/adafruit_httpserver/response.py +++ b/adafruit_httpserver/response.py @@ -8,7 +8,7 @@ """ try: - from typing import Optional, Dict, Union, Tuple + from typing import Optional, Dict, Union, Tuple, Callable from socket import socket from socketpool import SocketPool except ImportError: @@ -17,12 +17,33 @@ import os from errno import EAGAIN, ECONNRESET +from .exceptions import ( + BackslashInPathError, + FileNotExistsError, + ParentDirectoryReferenceError, + ResponseAlreadySentError, +) from .mime_type import MIMEType from .request import HTTPRequest from .status import HTTPStatus, CommonHTTPStatus from .headers import HTTPHeaders +def _prevent_multiple_send_calls(function: Callable): + """ + Decorator that prevents calling ``send`` or ``send_file`` more than once. + """ + + def wrapper(self: "HTTPResponse", *args, **kwargs): + if self._response_already_sent: # pylint: disable=protected-access + raise ResponseAlreadySentError + + result = function(self, *args, **kwargs) + return result + + return wrapper + + class HTTPResponse: """ Response to a given `HTTPRequest`. Use in `HTTPServer.route` decorator functions. @@ -73,8 +94,8 @@ def route_func(request): """ Defaults to ``text/plain`` if not set. - Can be explicitly provided in the constructor, in `send()` or - implicitly determined from filename in `send_file()`. + Can be explicitly provided in the constructor, in ``send()`` or + implicitly determined from filename in ``send_file()``. Common MIME types are defined in `adafruit_httpserver.mime_type.MIMEType`. """ @@ -94,7 +115,7 @@ def __init__( # pylint: disable=too-many-arguments Sets `status`, ``headers`` and `http_version` and optionally default ``content_type``. - To send the response, call `send` or `send_file`. + To send the response, call ``send`` or ``send_file``. For chunked response use ``with HTTPRequest(request, content_type=..., chunked=True) as r:`` and `send_chunk`. """ @@ -115,7 +136,7 @@ def _send_headers( ) -> None: """ Sends headers. - Implicitly called by `send` and `send_file` and in + Implicitly called by ``send`` and ``send_file`` and in ``with HTTPResponse(request, chunked=True) as response:`` context manager. """ headers = self.headers.copy() @@ -141,6 +162,7 @@ def _send_headers( self.request.connection, response_message_header.encode("utf-8") ) + @_prevent_multiple_send_calls def send( self, body: str = "", @@ -152,8 +174,6 @@ def send( Should be called **only once** per response. """ - if self._response_already_sent: - raise RuntimeError("Response was already sent") if getattr(body, "encode", None): encoded_response_message_body = body.encode("utf-8") @@ -167,12 +187,41 @@ def send( self._send_bytes(self.request.connection, encoded_response_message_body) self._response_already_sent = True - def send_file( + @staticmethod + def _check_file_path_is_valid(file_path: str) -> bool: + """ + Checks if ``file_path`` is valid. + If not raises error corresponding to the problem. + """ + + # Check for backslashes + if "\\" in file_path: # pylint: disable=anomalous-backslash-in-string + raise BackslashInPathError(file_path) + + # Check each component of the path for parent directory references + for part in file_path.split("/"): + if part == "..": + raise ParentDirectoryReferenceError(file_path) + + @staticmethod + def _get_file_length(file_path: str) -> int: + """ + Tries to get the length of the file at ``file_path``. + Raises ``FileNotExistsError`` if file does not exist. + """ + try: + return os.stat(file_path)[6] + except OSError: + raise FileNotExistsError(file_path) # pylint: disable=raise-missing-from + + @_prevent_multiple_send_calls + def send_file( # pylint: disable=too-many-arguments self, filename: str = "index.html", root_path: str = "./", buffer_size: int = 1024, head_only: bool = False, + safe: bool = True, ) -> None: """ Send response with content of ``filename`` located in ``root_path``. @@ -181,17 +230,18 @@ def send_file( Should be called **only once** per response. """ - if self._response_already_sent: - raise RuntimeError("Response was already sent") + + if safe: + self._check_file_path_is_valid(filename) if not root_path.endswith("/"): root_path += "/" - try: - file_length = os.stat(root_path + filename)[6] - except OSError: - # If the file doesn't exist, return 404. - HTTPResponse(self.request, status=CommonHTTPStatus.NOT_FOUND_404).send() - return + if filename.startswith("/"): + filename = filename[1:] + + full_file_path = root_path + filename + + file_length = self._get_file_length(full_file_path) self._send_headers( content_type=MIMEType.from_file_name(filename), @@ -199,7 +249,7 @@ def send_file( ) if not head_only: - with open(root_path + filename, "rb") as file: + with open(full_file_path, "rb") as file: while bytes_read := file.read(buffer_size): self._send_bytes(self.request.connection, bytes_read) self._response_already_sent = True diff --git a/adafruit_httpserver/server.py b/adafruit_httpserver/server.py index 8e86e82..8662403 100644 --- a/adafruit_httpserver/server.py +++ b/adafruit_httpserver/server.py @@ -16,6 +16,7 @@ from errno import EAGAIN, ECONNRESET, ETIMEDOUT +from .exceptions import FileNotExistsError, InvalidPathError from .methods import HTTPMethod from .request import HTTPRequest from .response import HTTPResponse @@ -26,18 +27,19 @@ class HTTPServer: """A basic socket-based HTTP server.""" - def __init__(self, socket_source: Protocol) -> None: + def __init__(self, socket_source: Protocol, root_path: str) -> None: """Create a server, and get it ready to run. :param socket: An object that is a source of sockets. This could be a `socketpool` in CircuitPython or the `socket` module in CPython. + :param str root_path: Root directory to serve files from """ self._buffer = bytearray(1024) self._timeout = 1 self.routes = _HTTPRoutes() self._socket_source = socket_source self._sock = None - self.root_path = "/" + self.root_path = root_path def route(self, path: str, method: HTTPMethod = HTTPMethod.GET) -> Callable: """ @@ -63,14 +65,13 @@ def route_decorator(func: Callable) -> Callable: return route_decorator - def serve_forever(self, host: str, port: int = 80, root_path: str = "") -> None: + def serve_forever(self, host: str, port: int = 80) -> None: """Wait for HTTP requests at the given host and port. Does not return. :param str host: host name or IP address :param int port: port - :param str root_path: root directory to serve files from """ - self.start(host, port, root_path) + self.start(host, port) while True: try: @@ -78,17 +79,14 @@ def serve_forever(self, host: str, port: int = 80, root_path: str = "") -> None: except OSError: continue - def start(self, host: str, port: int = 80, root_path: str = "") -> None: + def start(self, host: str, port: int = 80) -> None: """ Start the HTTP server at the given host and port. Requires calling poll() in a while loop to handle incoming requests. :param str host: host name or IP address :param int port: port - :param str root_path: root directory to serve files from """ - self.root_path = root_path - self._sock = self._socket_source.socket( self._socket_source.AF_INET, self._socket_source.SOCK_STREAM ) @@ -158,38 +156,50 @@ def poll(self): conn, received_body_bytes, content_length ) + # Find a handler for the route handler = self.routes.find_handler( _HTTPRoute(request.path, request.method) ) - # If a handler for route exists and is callable, call it. - if handler is not None and callable(handler): - handler(request) - - # If no handler exists and request method is GET, try to serve a file. - elif handler is None and request.method in ( - HTTPMethod.GET, - HTTPMethod.HEAD, - ): - filename = "index.html" if request.path == "/" else request.path - HTTPResponse(request).send_file( - filename=filename, - root_path=self.root_path, - buffer_size=self.request_buffer_size, - head_only=(request.method == HTTPMethod.HEAD), + try: + # If a handler for route exists and is callable, call it. + if handler is not None and callable(handler): + handler(request) + + # If no handler exists and request method is GET or HEAD, try to serve a file. + elif handler is None and request.method in ( + HTTPMethod.GET, + HTTPMethod.HEAD, + ): + filename = "index.html" if request.path == "/" else request.path + HTTPResponse(request).send_file( + filename=filename, + root_path=self.root_path, + buffer_size=self.request_buffer_size, + head_only=(request.method == HTTPMethod.HEAD), + ) + else: + HTTPResponse( + request, status=CommonHTTPStatus.BAD_REQUEST_400 + ).send() + + except InvalidPathError as error: + HTTPResponse(request, status=CommonHTTPStatus.FORBIDDEN_403).send( + str(error) ) - else: - HTTPResponse( - request, status=CommonHTTPStatus.BAD_REQUEST_400 - ).send() - - except OSError as ex: - # handle EAGAIN and ECONNRESET - if ex.errno == EAGAIN: - # there is no data available right now, try again later. + + except FileNotExistsError as error: + HTTPResponse(request, status=CommonHTTPStatus.NOT_FOUND_404).send( + str(error) + ) + + except OSError as error: + # Handle EAGAIN and ECONNRESET + if error.errno == EAGAIN: + # There is no data available right now, try again later. return - if ex.errno == ECONNRESET: - # connection reset by peer, try again later. + if error.errno == ECONNRESET: + # Connection reset by peer, try again later. return raise @@ -204,7 +214,7 @@ def request_buffer_size(self) -> int: Example:: - server = HTTPServer(pool) + server = HTTPServer(pool, "/static") server.request_buffer_size = 2048 server.serve_forever(str(wifi.radio.ipv4_address)) @@ -226,7 +236,7 @@ def socket_timeout(self) -> int: Example:: - server = HTTPServer(pool) + server = HTTPServer(pool, "/static") server.socket_timeout = 3 server.serve_forever(str(wifi.radio.ipv4_address)) diff --git a/adafruit_httpserver/status.py b/adafruit_httpserver/status.py index d32538c..8a7b198 100644 --- a/adafruit_httpserver/status.py +++ b/adafruit_httpserver/status.py @@ -39,6 +39,9 @@ class CommonHTTPStatus(HTTPStatus): # pylint: disable=too-few-public-methods BAD_REQUEST_400 = HTTPStatus(400, "Bad Request") """400 Bad Request""" + FORBIDDEN_403 = HTTPStatus(403, "Forbidden") + """403 Forbidden""" + NOT_FOUND_404 = HTTPStatus(404, "Not Found") """404 Not Found""" diff --git a/docs/api.rst b/docs/api.rst index 4615507..64bb534 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -27,3 +27,6 @@ .. automodule:: adafruit_httpserver.mime_type :members: + +.. automodule:: adafruit_httpserver.exceptions + :members: diff --git a/examples/httpserver_chunked.py b/examples/httpserver_chunked.py index ae519ec..ed67fc6 100644 --- a/examples/httpserver_chunked.py +++ b/examples/httpserver_chunked.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Unlicense -import secrets # pylint: disable=no-name-in-module +import os import socketpool import wifi @@ -12,14 +12,15 @@ from adafruit_httpserver.server import HTTPServer -ssid, password = secrets.WIFI_SSID, secrets.WIFI_PASSWORD # pylint: disable=no-member +ssid = os.getenv("WIFI_SSID") +password = os.getenv("WIFI_PASSWORD") print("Connecting to", ssid) wifi.radio.connect(ssid, password) print("Connected to", ssid) pool = socketpool.SocketPool(wifi.radio) -server = HTTPServer(pool) +server = HTTPServer(pool, "/static") @server.route("/chunked") diff --git a/examples/httpserver_cpu_information.py b/examples/httpserver_cpu_information.py index cf3d13b..41d7a05 100644 --- a/examples/httpserver_cpu_information.py +++ b/examples/httpserver_cpu_information.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Unlicense -import secrets # pylint: disable=no-name-in-module +import os import json import microcontroller @@ -15,14 +15,15 @@ from adafruit_httpserver.server import HTTPServer -ssid, password = secrets.WIFI_SSID, secrets.WIFI_PASSWORD # pylint: disable=no-member +ssid = os.getenv("WIFI_SSID") +password = os.getenv("WIFI_PASSWORD") print("Connecting to", ssid) wifi.radio.connect(ssid, password) print("Connected to", ssid) pool = socketpool.SocketPool(wifi.radio) -server = HTTPServer(pool) +server = HTTPServer(pool, "/static") @server.route("/cpu-information") diff --git a/examples/httpserver_mdns.py b/examples/httpserver_mdns.py index d2228c9..bebdc2a 100644 --- a/examples/httpserver_mdns.py +++ b/examples/httpserver_mdns.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Unlicense -import secrets # pylint: disable=no-name-in-module +import os import mdns import socketpool @@ -14,7 +14,8 @@ from adafruit_httpserver.server import HTTPServer -ssid, password = secrets.WIFI_SSID, secrets.WIFI_PASSWORD # pylint: disable=no-member +ssid = os.getenv("WIFI_SSID") +password = os.getenv("WIFI_PASSWORD") print("Connecting to", ssid) wifi.radio.connect(ssid, password) @@ -25,7 +26,7 @@ mdns_server.advertise_service(service_type="_http", protocol="_tcp", port=80) pool = socketpool.SocketPool(wifi.radio) -server = HTTPServer(pool) +server = HTTPServer(pool, "/static") @server.route("/") diff --git a/examples/httpserver_neopixel.py b/examples/httpserver_neopixel.py index 814a7af..baff3de 100644 --- a/examples/httpserver_neopixel.py +++ b/examples/httpserver_neopixel.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Unlicense -import secrets # pylint: disable=no-name-in-module +import os import board import neopixel @@ -15,14 +15,15 @@ from adafruit_httpserver.server import HTTPServer -ssid, password = secrets.WIFI_SSID, secrets.WIFI_PASSWORD # pylint: disable=no-member +ssid = os.getenv("WIFI_SSID") +password = os.getenv("WIFI_PASSWORD") print("Connecting to", ssid) wifi.radio.connect(ssid, password) print("Connected to", ssid) pool = socketpool.SocketPool(wifi.radio) -server = HTTPServer(pool) +server = HTTPServer(pool, "/static") pixel = neopixel.NeoPixel(board.NEOPIXEL, 1) diff --git a/examples/httpserver_simple_poll.py b/examples/httpserver_simple_poll.py index db876c4..1ed5027 100644 --- a/examples/httpserver_simple_poll.py +++ b/examples/httpserver_simple_poll.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Unlicense -import secrets # pylint: disable=no-name-in-module +import os import socketpool import wifi @@ -13,14 +13,15 @@ from adafruit_httpserver.server import HTTPServer -ssid, password = secrets.WIFI_SSID, secrets.WIFI_PASSWORD # pylint: disable=no-member +ssid = os.getenv("WIFI_SSID") +password = os.getenv("WIFI_PASSWORD") print("Connecting to", ssid) wifi.radio.connect(ssid, password) print("Connected to", ssid) pool = socketpool.SocketPool(wifi.radio) -server = HTTPServer(pool) +server = HTTPServer(pool, "/static") @server.route("/") diff --git a/examples/httpserver_simple_serve.py b/examples/httpserver_simple_serve.py index 632c234..226d8f2 100644 --- a/examples/httpserver_simple_serve.py +++ b/examples/httpserver_simple_serve.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Unlicense -import secrets # pylint: disable=no-name-in-module +import os import socketpool import wifi @@ -13,14 +13,15 @@ from adafruit_httpserver.server import HTTPServer -ssid, password = secrets.WIFI_SSID, secrets.WIFI_PASSWORD # pylint: disable=no-member +ssid = os.getenv("WIFI_SSID") +password = os.getenv("WIFI_PASSWORD") print("Connecting to", ssid) wifi.radio.connect(ssid, password) print("Connected to", ssid) pool = socketpool.SocketPool(wifi.radio) -server = HTTPServer(pool) +server = HTTPServer(pool, "/static") @server.route("/") diff --git a/examples/httpserver_url_parameters.py b/examples/httpserver_url_parameters.py index 2f95163..22f9f3b 100644 --- a/examples/httpserver_url_parameters.py +++ b/examples/httpserver_url_parameters.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Unlicense -import secrets # pylint: disable=no-name-in-module +import os import socketpool import wifi @@ -13,14 +13,15 @@ from adafruit_httpserver.server import HTTPServer -ssid, password = secrets.WIFI_SSID, secrets.WIFI_PASSWORD # pylint: disable=no-member +ssid = os.getenv("WIFI_SSID") +password = os.getenv("WIFI_PASSWORD") print("Connecting to", ssid) wifi.radio.connect(ssid, password) print("Connected to", ssid) pool = socketpool.SocketPool(wifi.radio) -server = HTTPServer(pool) +server = HTTPServer(pool, "/static") class Device: