Skip to content

INTPYTHON-380 Add typing #176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = []
html_static_path: list[str] = []

# If not "", a "Last updated on:" timestamp is inserted at every page bottom,
# using the given strftime format.
Expand Down Expand Up @@ -182,14 +182,14 @@

# -- Options for LaTeX output --------------------------------------------------
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you removing the section? what's happening here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was an empty dict, which tripped up the type checker, so I just commented out to use the default value.


latex_elements = {
# The paper size ("letterpaper" or "a4paper").
# "papersize": "letterpaper",
# The font size ("10pt", "11pt" or "12pt").
# "pointsize": "10pt",
# Additional stuff for the LaTeX preamble.
# "preamble": "",
}
# latex_elements = {
# The paper size ("letterpaper" or "a4paper").
# "papersize": "letterpaper",
# The font size ("10pt", "11pt" or "12pt").
# "pointsize": "10pt",
# Additional stuff for the LaTeX preamble.
# "preamble": "",
# }

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]).
Expand Down Expand Up @@ -228,9 +228,7 @@

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
("index", "flask-pymongo", "Flask-PyMongo Documentation", ["Dan Crosta"], 1)
]
man_pages = [("index", "flask-pymongo", "Flask-PyMongo Documentation", ["Dan Crosta"], 1)]

# If true, show URL addresses after external links.
# man_show_urls = False
Expand Down
31 changes: 19 additions & 12 deletions examples/wiki/wiki.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any

import markdown2
import markdown2 # type:ignore[import-untyped]
from flask import Flask, redirect, render_template, request, url_for

from flask_pymongo import PyMongo

if TYPE_CHECKING:
from werkzeug.wrappers.response import Response

app = Flask(__name__)
mongo = PyMongo(app, "mongodb://localhost/wiki")

Expand All @@ -15,17 +19,17 @@


@app.route("/", methods=["GET"])
def redirect_to_homepage():
def redirect_to_homepage() -> Response:
return redirect(url_for("show_page", pagepath="HomePage"))


@app.template_filter()
def totitle(value):
def totitle(value: str) -> str:
return " ".join(WIKIPART.findall(value))


@app.template_filter()
def wikify(value):
def wikify(value: str) -> Any:
parts = WIKIWORD.split(value)
for i, part in enumerate(parts):
if WIKIWORD.match(part):
Expand All @@ -36,20 +40,23 @@ def wikify(value):


@app.route("/<path:pagepath>")
def show_page(pagepath):
page = mongo.db.pages.find_one_or_404({"_id": pagepath})
def show_page(pagepath: str) -> str:
assert mongo.db is not None
page: dict[str, Any] = mongo.db.pages.find_one_or_404({"_id": pagepath})
return render_template("page.html", page=page, pagepath=pagepath)


@app.route("/edit/<path:pagepath>", methods=["GET"])
def edit_page(pagepath):
page = mongo.db.pages.find_one_or_404({"_id": pagepath})
def edit_page(pagepath: str) -> str:
assert mongo.db is not None
page: dict[str, Any] = mongo.db.pages.find_one_or_404({"_id": pagepath})
return render_template("edit.html", page=page, pagepath=pagepath)


@app.route("/edit/<path:pagepath>", methods=["POST"])
def save_page(pagepath):
def save_page(pagepath: str) -> Response:
if "cancel" not in request.form:
assert mongo.db is not None
mongo.db.pages.update(
{"_id": pagepath},
{"$set": {"body": request.form["body"]}},
Expand All @@ -60,7 +67,7 @@ def save_page(pagepath):


@app.errorhandler(404)
def new_page(error):
def new_page(error: Any) -> str:
pagepath = request.path.lstrip("/")
if pagepath.startswith("uploads"):
filename = pagepath[len("uploads") :].lstrip("/")
Expand All @@ -69,12 +76,12 @@ def new_page(error):


@app.route("/uploads/<path:filename>")
def get_upload(filename):
def get_upload(filename: str) -> Response:
return mongo.send_file(filename)


@app.route("/uploads/<path:filename>", methods=["POST"])
def save_upload(filename):
def save_upload(filename: str) -> str | Response:
if request.files.get("file"):
mongo.save_file(filename, request.files["file"])
return redirect(url_for("get_upload", filename=filename))
Expand Down
46 changes: 26 additions & 20 deletions flask_pymongo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,22 @@
# POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations

__all__ = ("PyMongo", "ASCENDING", "DESCENDING")
__all__ = ("PyMongo", "ASCENDING", "DESCENDING", "BSONObjectIdConverter", "BSONProvider")

import hashlib
from mimetypes import guess_type
from typing import Any

import pymongo
from flask import abort, current_app, request
from flask import Flask, Response, abort, current_app, request
from gridfs import GridFS, NoFile
from pymongo import uri_parser
from pymongo.driver_info import DriverInfo
from werkzeug.wsgi import wrap_file

# DriverInfo was added in PyMongo 3.7
try:
from pymongo.driver_info import DriverInfo
except ImportError:
DriverInfo = None

from flask_pymongo._version import __version__
from flask_pymongo.helpers import BSONObjectIdConverter, BSONProvider
from flask_pymongo.wrappers import MongoClient
from flask_pymongo.wrappers import Database, MongoClient

DESCENDING = pymongo.DESCENDING
"""Descending sort order."""
Expand All @@ -65,15 +61,16 @@ class PyMongo:

"""

def __init__(self, app=None, uri=None, *args, **kwargs):
self.cx = None
self.db = None
self._json_provider = BSONProvider(app)
def __init__(
self, app: Flask | None = None, uri: str | None = None, *args: Any, **kwargs: Any
) -> None:
self.cx: MongoClient | None = None
self.db: Database | None = None

if app is not None:
self.init_app(app, uri, *args, **kwargs)

def init_app(self, app, uri=None, *args, **kwargs):
def init_app(self, app: Flask, uri: str | None = None, *args: Any, **kwargs: Any) -> None:
"""Initialize this :class:`PyMongo` for use.

Configure a :class:`~pymongo.mongo_client.MongoClient`
Expand Down Expand Up @@ -122,10 +119,12 @@ def init_app(self, app, uri=None, *args, **kwargs):
self.db = self.cx[database_name]

app.url_map.converters["ObjectId"] = BSONObjectIdConverter
app.json = self._json_provider
app.json = BSONProvider(app)

# view helpers
def send_file(self, filename, base="fs", version=-1, cache_for=31536000):
def send_file(
self, filename: str, base: str = "fs", version: int = -1, cache_for: int = 31536000
) -> Response:
"""Respond with a file from GridFS.

Returns an instance of the :attr:`~flask.Flask.response_class`
Expand Down Expand Up @@ -153,6 +152,7 @@ def get_upload(filename):
if not isinstance(cache_for, int):
raise TypeError("'cache_for' must be an integer")

assert self.db is not None, "Please initialize the app before calling send_file!"
storage = GridFS(self.db, base)

try:
Expand Down Expand Up @@ -183,7 +183,14 @@ def get_upload(filename):
response.make_conditional(request)
return response

def save_file(self, filename, fileobj, base="fs", content_type=None, **kwargs):
def save_file(
self,
filename: str,
fileobj: Any,
base: str = "fs",
content_type: str | None = None,
**kwargs: Any,
) -> Any:
"""Save a file-like object to GridFS using the given filename.
Return the "_id" of the created file.

Expand Down Expand Up @@ -211,8 +218,7 @@ def save_upload(filename):
if content_type is None:
content_type, _ = guess_type(filename)

assert self.db is not None, "Please initialize the app before calling save_file!"
storage = GridFS(self.db, base)
id = storage.put(
fileobj, filename=filename, content_type=content_type, **kwargs
)
id = storage.put(fileobj, filename=filename, content_type=content_type, **kwargs)
return id
16 changes: 9 additions & 7 deletions flask_pymongo/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

__all__ = ("BSONObjectIdConverter", "BSONProvider")

from typing import Any

from bson import json_util
from bson.errors import InvalidId
from bson.json_util import RELAXED_JSON_OPTIONS
Expand All @@ -35,7 +37,7 @@
from werkzeug.routing import BaseConverter


def _iteritems(obj):
def _iteritems(obj: Any) -> Any:
if hasattr(obj, "iteritems"):
return obj.iteritems()
if hasattr(obj, "items"):
Expand Down Expand Up @@ -65,13 +67,13 @@ def show_task(task_id):

"""

def to_python(self, value):
def to_python(self, value: Any) -> ObjectId:
try:
return ObjectId(value)
except InvalidId:
raise abort(404) from None

def to_url(self, value):
def to_url(self, value: Any) -> str:
return str(value)


Expand All @@ -98,15 +100,15 @@ def json_route(cart_id):
:const:`~bson.json_util.RELAXED_JSON_OPTIONS`.
"""

def __init__(self, app):
def __init__(self, app: Any) -> None:
self._default_kwargs = {"json_options": RELAXED_JSON_OPTIONS}

super().__init__(app)

def dumps(self, obj):
def dumps(self, obj: Any, **kwargs: Any) -> str:
"""Serialize MongoDB object types using :mod:`bson.json_util`."""
return json_util.dumps(obj)

def loads(self, str_obj):
def loads(self, s: str | bytes, **kwargs: Any) -> Any:
"""Deserialize MongoDB object types using :mod:`bson.json_util`."""
return json_util.loads(str_obj)
return json_util.loads(s)
File renamed without changes.
38 changes: 20 additions & 18 deletions flask_pymongo/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
# POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations

from typing import Any

from flask import abort
from pymongo import collection, database, mongo_client


class MongoClient(mongo_client.MongoClient):
class MongoClient(mongo_client.MongoClient[dict[str, Any]]):
"""Wrapper for :class:`~pymongo.mongo_client.MongoClient`.

Returns instances of Flask-PyMongo
Expand All @@ -37,20 +39,20 @@ class MongoClient(mongo_client.MongoClient):

"""

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
attr = super().__getattr__(name)
if isinstance(attr, database.Database):
return Database(self, name)
return attr

def __getitem__(self, item):
attr = super().__getitem__(item)
def __getitem__(self, name: str) -> Any:
attr = super().__getitem__(name)
if isinstance(attr, database.Database):
return Database(self, item)
return Database(self, name)
return attr


class Database(database.Database):
class Database(database.Database[dict[str, Any]]):
"""Wrapper for :class:`~pymongo.database.Database`.

Returns instances of Flask-PyMongo
Expand All @@ -59,37 +61,37 @@ class Database(database.Database):

"""

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
attr = super().__getattr__(name)
if isinstance(attr, collection.Collection):
return Collection(self, name)
return attr

def __getitem__(self, item):
item_ = super().__getitem__(item)
def __getitem__(self, name: str) -> Any:
item_ = super().__getitem__(name)
if isinstance(item_, collection.Collection):
return Collection(self, item)
return Collection(self, name)
return item_


class Collection(collection.Collection):
class Collection(collection.Collection[dict[str, Any]]):
"""Sub-class of PyMongo :class:`~pymongo.collection.Collection` with helpers."""

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
attr = super().__getattr__(name)
if isinstance(attr, collection.Collection):
db = self._Collection__database
return Collection(db, attr.name)
return attr

def __getitem__(self, item):
item_ = super().__getitem__(item)
if isinstance(item_, collection.Collection):
def __getitem__(self, name: str) -> Any:
item = super().__getitem__(name)
if isinstance(item, collection.Collection):
db = self._Collection__database
return Collection(db, item_.name)
return item_
return Collection(db, item.name)
return item

def find_one_or_404(self, *args, **kwargs):
def find_one_or_404(self, *args: Any, **kwargs: Any) -> Any:
"""Find a single document or raise a 404.

This is like :meth:`~pymongo.collection.Collection.find_one`, but
Expand Down
3 changes: 3 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ lint:

docs:
uv run sphinx-build -T -b html docs docs/_build

typing:
uv run mypy --install-types --non-interactive .
Loading
Loading