Skip to content

Add fixes for type stub generation #828

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions shiny/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _server(inputs: Inputs, outputs: Outputs, session: Session):
cast("Tag | TagList", ui), lib_prefix=self.lib_prefix
)

def init_starlette_app(self):
def init_starlette_app(self) -> starlette.applications.Starlette:
routes: list[starlette.routing.BaseRoute] = [
starlette.routing.WebSocketRoute("/websocket/", self._on_connect_cb),
starlette.routing.Route("/", self._on_root_request_cb, methods=["GET"]),
Expand Down Expand Up @@ -400,7 +400,7 @@ def _render_page_from_file(self, file: Path, lib_prefix: str) -> RenderedHTML:
return rendered


def is_uifunc(x: Path | Tag | TagList | Callable[[Request], Tag | TagList]):
def is_uifunc(x: Path | Tag | TagList | Callable[[Request], Tag | TagList]) -> bool:
if (
isinstance(x, Path)
or isinstance(x, Tag)
Expand Down
6 changes: 3 additions & 3 deletions shiny/_namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from contextlib import contextmanager
from contextvars import ContextVar, Token
from typing import Pattern, Union, overload
from typing import Generator, Pattern, Union, overload


class ResolvedId(str):
Expand Down Expand Up @@ -82,7 +82,7 @@ def resolve_id_or_none(id: Id | None) -> ResolvedId | None:
re_valid_id: Pattern[str] = re.compile("^\\.?\\w+$")


def validate_id(id: str):
def validate_id(id: str) -> None:
if not re_valid_id.match(id):
raise ValueError(
f"The string '{id}' is not a valid id; only letters, numbers, and "
Expand All @@ -97,7 +97,7 @@ def validate_id(id: str):


@contextmanager
def namespace_context(id: Id | None):
def namespace_context(id: Id | None) -> Generator[None, None, None]:
namespace = resolve_id(id) if id else Root
token: Token[ResolvedId | None] = _current_namespace.set(namespace)
try:
Expand Down
4 changes: 2 additions & 2 deletions shiny/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import secrets
import socketserver
import tempfile
from typing import Any, Awaitable, Callable, Optional, TypeVar, cast
from typing import Any, Awaitable, Callable, Generator, Optional, TypeVar, cast

from ._typing_extensions import ParamSpec, TypeGuard

Expand Down Expand Up @@ -200,7 +200,7 @@ def private_random_int(min: int, max: int) -> str:


@contextlib.contextmanager
def private_seed():
def private_seed() -> Generator[None, None, None]:
state = random.getstate()
global own_random_state
try:
Expand Down
8 changes: 4 additions & 4 deletions shiny/express/_is_express.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def __init__(self):
super().__init__()
self.found_shiny_express_import = False

def visit_Import(self, node: ast.Import):
def visit_Import(self, node: ast.Import) -> None:
if any(alias.name == "shiny.express" for alias in node.names):
self.found_shiny_express_import = True

def visit_ImportFrom(self, node: ast.ImportFrom):
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module == "shiny.express":
self.found_shiny_express_import = True
elif node.module == "shiny" and any(
Expand All @@ -69,9 +69,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom):
self.found_shiny_express_import = True

# Visit top-level nodes.
def visit_Module(self, node: ast.Module):
def visit_Module(self, node: ast.Module) -> None:
super().generic_visit(node)

# Don't recurse into any nodes, so the we'll only ever look at top-level nodes.
def generic_visit(self, node: ast.AST):
def generic_visit(self, node: ast.AST) -> None:
pass
4 changes: 2 additions & 2 deletions shiny/express/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import contextlib
import sys
from contextlib import AbstractContextManager
from typing import Callable, TypeVar, cast, overload
from typing import Callable, Generator, TypeVar, cast, overload

from .. import ui
from .._typing_extensions import ParamSpec
Expand Down Expand Up @@ -109,7 +109,7 @@ def suspend_display(


@contextlib.contextmanager
def suspend_display_ctxmgr():
def suspend_display_ctxmgr() -> Generator[None, None, None]:
oldhook = sys.displayhook
sys.displayhook = null_displayhook
try:
Expand Down
4 changes: 2 additions & 2 deletions shiny/express/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,14 @@ def set_result(x: object):
_top_level_recall_context_manager_has_been_replaced = False


def reset_top_level_recall_context_manager():
def reset_top_level_recall_context_manager() -> None:
global _top_level_recall_context_manager
global _top_level_recall_context_manager_has_been_replaced
_top_level_recall_context_manager = RecallContextManager(_DEFAULT_PAGE_FUNCTION)
_top_level_recall_context_manager_has_been_replaced = False


def get_top_level_recall_context_manager():
def get_top_level_recall_context_manager() -> RecallContextManager[Tag]:
return _top_level_recall_context_manager


Expand Down
3 changes: 2 additions & 1 deletion shiny/express/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from pathlib import Path

from .._app import App
from ._run import wrap_express_app
from ._utils import unescape_from_var_name


# If someone requests shiny.express.app:_2f_path_2f_to_2f_app_2e_py, then we will call
# wrap_express_app(Path("/path/to/app.py")) and return the result.
def __getattr__(name: str):
def __getattr__(name: str) -> App:
name = unescape_from_var_name(name)
return wrap_express_app(Path(name))
8 changes: 5 additions & 3 deletions shiny/express/display_decorator/_display_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def unwrap(fn: TFunc) -> TFunc:
display_body_attr = "__display_body__"


def display_body_unwrap_inplace():
def display_body_unwrap_inplace() -> Callable[[TFunc], TFunc]:
"""
Like `display_body`, but far more violent. This will attempt to traverse any
decorators between this one and the function, and then modify the function _in
Expand Down Expand Up @@ -76,7 +76,7 @@ def decorator(fn: TFunc) -> TFunc:
return decorator


def display_body():
def display_body() -> Callable[[TFunc], TFunc]:
def decorator(fn: TFunc) -> TFunc:
if fn.__code__ in code_cache:
fcode = code_cache[fn.__code__]
Expand Down Expand Up @@ -197,7 +197,9 @@ def _transform_function_ast(node: ast.AST) -> ast.AST:
return func_node


def compare_decorated_code_objects(func_ast: ast.FunctionDef):
def compare_decorated_code_objects(
func_ast: ast.FunctionDef,
) -> Callable[[types.CodeType, types.CodeType], bool]:
linenos = [*[x.lineno for x in func_ast.decorator_list], func_ast.lineno]

def comparator(candidate: types.CodeType, target: types.CodeType) -> bool:
Expand Down
30 changes: 17 additions & 13 deletions shiny/express/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from .. import ui
from ..types import MISSING, MISSING_TYPE
from ..ui._accordion import AccordionPanel
from ..ui._layout_columns import BreakpointsUser
from ..ui._navs import NavPanel, NavSet, NavSetCard
from ..ui.css import CssUnit
from . import _run
from ._recall_context import RecallContextManager, wrap_recall_context_manager
Expand Down Expand Up @@ -39,7 +41,7 @@
# ======================================================================================
# Page functions
# ======================================================================================
def set_page(page_fn: RecallContextManager[Tag]):
def set_page(page_fn: RecallContextManager[Tag]) -> None:
"""Set the page function for the current Shiny express app."""
_run.replace_top_level_recall_context_manager(page_fn, force=True)

Expand Down Expand Up @@ -162,7 +164,7 @@ def layout_column_wrap(
gap: Optional[CssUnit] = None,
class_: Optional[str] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
A grid-like, column-first layout

Expand Down Expand Up @@ -252,7 +254,7 @@ def layout_columns(
class_: Optional[str] = None,
height: Optional[CssUnit] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
Create responsive, column-based grid layouts, based on a 12-column grid.

Expand Down Expand Up @@ -346,7 +348,9 @@ def layout_columns(
)


def column(width: int, *, offset: int = 0, **kwargs: TagAttrValue):
def column(
width: int, *, offset: int = 0, **kwargs: TagAttrValue
) -> RecallContextManager[Tag]:
"""
Responsive row-column based layout

Expand Down Expand Up @@ -381,7 +385,7 @@ def column(width: int, *, offset: int = 0, **kwargs: TagAttrValue):
)


def row(**kwargs: TagAttrValue):
def row(**kwargs: TagAttrValue) -> RecallContextManager[Tag]:
"""
Responsive row-column based layout

Expand Down Expand Up @@ -419,7 +423,7 @@ def card(
fill: bool = True,
class_: Optional[str] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
A Bootstrap card component

Expand Down Expand Up @@ -481,7 +485,7 @@ def accordion(
width: Optional[CssUnit] = None,
height: Optional[CssUnit] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
Create a vertically collapsing accordion.

Expand Down Expand Up @@ -537,7 +541,7 @@ def accordion_panel(
value: Optional[str] | MISSING_TYPE = MISSING,
icon: Optional[TagChild] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[AccordionPanel]:
"""
Single accordion panel.

Expand Down Expand Up @@ -583,7 +587,7 @@ def navset(
selected: Optional[str] = None,
header: TagChild = None,
footer: TagChild = None,
):
) -> RecallContextManager[NavSet]:
"""
Render a set of nav items

Expand Down Expand Up @@ -635,7 +639,7 @@ def navset_card(
sidebar: Optional[ui.Sidebar] = None,
header: TagChild = None,
footer: TagChild = None,
):
) -> RecallContextManager[NavSetCard]:
"""
Render a set of nav items inside a card container.

Expand Down Expand Up @@ -687,7 +691,7 @@ def nav_panel(
*,
value: Optional[str] = None,
icon: TagChild = None,
):
) -> RecallContextManager[NavPanel]:
"""
Create a nav item pointing to some internal content.

Expand Down Expand Up @@ -803,7 +807,7 @@ def page_fillable(
title: Optional[str] = None,
lang: Optional[str] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
Creates a fillable page.

Expand Down Expand Up @@ -854,7 +858,7 @@ def page_sidebar(
window_title: str | MISSING_TYPE = MISSING,
lang: Optional[str] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
Create a page with a sidebar and a title.

Expand Down
2 changes: 1 addition & 1 deletion shiny/http_staticfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class StaticFiles:
def __init__(self, *, directory: str | os.PathLike[str]):
self.dir = pathlib.Path(os.path.realpath(os.path.normpath(directory)))

async def __call__(self, scope: Scope, receive: Receive, send: Send):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
raise AssertionError("StaticFiles can't handle non-http request")
path = scope["path"]
Expand Down
2 changes: 1 addition & 1 deletion shiny/input_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _(func: InputHandlerType):

return _

def remove(self, type: str):
def remove(self, type: str) -> None:
del self[type]

def _process_value(self, type: str, value: Any, name: str, session: Session) -> Any:
Expand Down
6 changes: 3 additions & 3 deletions shiny/reactive/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import typing
import warnings
from contextvars import ContextVar
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, TypeVar
from typing import TYPE_CHECKING, Awaitable, Callable, Generator, Optional, TypeVar

from .. import _utils
from .._datastructures import PriorityQueueFIFO
Expand Down Expand Up @@ -188,7 +188,7 @@ def add_pending_flush(self, ctx: Context, priority: int) -> None:
self._pending_flush_queue.put(priority, ctx)

@contextlib.contextmanager
def isolate(self):
def isolate(self) -> Generator[None, None, None]:
token = self._current_context.set(Context())
try:
yield
Expand All @@ -201,7 +201,7 @@ def isolate(self):

@add_example()
@contextlib.contextmanager
def isolate():
def isolate() -> Generator[None, None, None]:
"""
Create a non-reactive scope within a reactive scope.

Expand Down
6 changes: 3 additions & 3 deletions shiny/render/_try_render_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import base64
import io
import warnings
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, cast

from ..types import ImgData, PlotnineFigure
from ._coordmap import get_coordmap, get_coordmap_plotnine

TryPlotResult = Tuple[bool, "ImgData| None"]
TryPlotResult = Tuple[bool, Union[ImgData, None]]


if TYPE_CHECKING:
Expand Down Expand Up @@ -376,7 +376,7 @@ def try_render_plotnine(
#
# One negative consequence of this logic: if the user intentionally set the dpi to
# rcParam * device_pixel_ratio, we're going to ignore it.
def get_desired_dpi_from_fig(fig: Figure):
def get_desired_dpi_from_fig(fig: Figure) -> float:
ppi_out = fig.get_dpi()

if fig.canvas.device_pixel_ratio != 1 and hasattr(fig, "_original_dpi"):
Expand Down
6 changes: 6 additions & 0 deletions shiny/render/transformer/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Generic,
NamedTuple,
Optional,
Type,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -572,6 +573,11 @@ class OutputTransformer(Generic[IT, OT, P]):
* :class:`~shiny.render.transformer.OutputRenderer`
"""

fn: OutputTransformerFn[IT, P, OT]
ValueFn: Type[ValueFn[IT]]
OutputRenderer: Type[OutputRenderer[OT]]
OutputRendererDecorator: Type[OutputRendererDecorator[IT, OT]]

def params(
self,
*args: P.args,
Expand Down
Loading