Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions generator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,11 @@ def _get_additional_methods(self, class_name: str) -> List[str]:
indent = " " * 4
if class_name == "Position":
return [
"def __eq__(self, o: 'Position') -> Union[bool, 'NotImplemented']:",
"def __eq__(self, o: object) -> bool:",
f"{indent}if not isinstance(o, Position):",
f"{indent}{indent}return NotImplemented",
f"{indent}return (self.line, self.character) == (o.line, o.character)",
"def __gt__(self, o: 'Position') -> Union[bool, 'NotImplemented']:",
"def __gt__(self, o: 'Position') -> bool:",
f"{indent}if not isinstance(o, Position):",
f"{indent}{indent}return NotImplemented",
f"{indent}return (self.line, self.character) > (o.line, o.character)",
Expand All @@ -316,7 +316,7 @@ def _get_additional_methods(self, class_name: str) -> List[str]:
]
if class_name == "Range":
return [
"def __eq__(self, o: 'Range') -> Union[bool, 'NotImplemented']:",
"def __eq__(self, o: object) -> bool:",
f"{indent}if not isinstance(o, Range):",
f"{indent}{indent}return NotImplemented",
f"{indent}return (self.start == o.start) and (self.end == o.end)",
Expand All @@ -325,7 +325,7 @@ def _get_additional_methods(self, class_name: str) -> List[str]:
]
if class_name == "Location":
return [
"def __eq__(self, o:'Location') -> Union[bool, 'NotImplemented']:",
"def __eq__(self, o: object) -> bool:",
f"{indent}if not isinstance(o, Location):",
f"{indent}{indent}return NotImplemented",
f"{indent}return (self.uri == o.uri) and (self.range == o.range)",
Expand Down Expand Up @@ -925,7 +925,7 @@ def _get_utility_code(self, lsp_model: model.LSPModel) -> List[str]:
f"_KEYWORD_CLASSES = [{', '.join(sorted(set(self._keyword_classes)))}]"
]
code_lines += [
"def is_keyword_class(cls) -> bool:",
"def is_keyword_class(cls: type) -> bool:",
' """Returns true if the class has a property that may be python keyword."""',
" return any(cls is c for c in _KEYWORD_CLASSES)",
"",
Expand All @@ -938,7 +938,7 @@ def _get_utility_code(self, lsp_model: model.LSPModel) -> List[str]:
f"_SPECIAL_CLASSES = [{', '.join(sorted(set(self._special_classes)))}]"
]
code_lines += [
"def is_special_class(cls) -> bool:",
"def is_special_class(cls: type) -> bool:",
' """Returns true if the class or its properties require special handling."""',
" return any(cls is c for c in _SPECIAL_CLASSES)",
"",
Expand All @@ -961,7 +961,7 @@ def _get_utility_code(self, lsp_model: model.LSPModel) -> List[str]:
f"_SPECIAL_PROPERTIES = [{', '.join(sorted(set(self._special_properties)))}]"
]
code_lines += [
"def is_special_property(cls, property_name:str) -> bool:",
"def is_special_property(cls: type, property_name:str) -> bool:",
' """Returns true if the class or its properties require special handling.',
" Example:",
" Consider RenameRegistrationOptions",
Expand All @@ -978,7 +978,7 @@ def _get_utility_code(self, lsp_model: model.LSPModel) -> List[str]:
"",
]

code_lines += ["", "ALL_TYPES_MAP = {"]
code_lines += ["", "ALL_TYPES_MAP: Dict[str, type] = {"]
code_lines += sorted([f"'{name}': {name}," for name in set(self._types.keys())])
code_lines += ["}", ""]

Expand Down Expand Up @@ -1061,9 +1061,8 @@ def _generate_hook(

hook_name = f"_{_to_snake_case(property_def.name)}_hook"
indent = " " * 4
code_lines = [
f"def {hook_name}(object_: Any, _: type):",
]
code_lines = []
return_types = []

has_base_type = False
ref_types = []
Expand All @@ -1078,17 +1077,21 @@ def _generate_hook(
return []

if property_def.optional:
return_types.append("None")
code_lines += [
f"{indent}if object_ is None:",
f"{indent*2}return None",
]

if has_base_type:
return_types += ["bool", "int", "str", "float"]
code_lines += [
f"{indent}if isinstance(object_, (bool, int, str, float)):",
f"{indent*2}return object_",
]

return_types += ref_types

if len(ref_types) == 1:
code_lines += [
f"{indent}return converter.structure(object_, {ref_types[0]})",
Expand All @@ -1106,6 +1109,9 @@ def _generate_hook(
f"{indent*2}return converter.structure(object_, {opt})",
]

declaration = f"def {hook_name}(object_: Any, _: type) -> Union[{', '.join(return_types)}]:"
code_lines.insert(0, declaration)

type_name = self._generate_type_name(property_def.type, None, "lsp_types.")
if property_def.optional:
type_name = f"Optional[{type_name}]"
Expand Down
10 changes: 4 additions & 6 deletions lsprotocol/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def resolve_forward_references() -> None:


def get_converter(
converter: Optional[cattrs.Converter] = cattrs.Converter(),
converter: cattrs.Converter = cattrs.Converter(),
) -> cattrs.Converter:
"""Adds cattrs hooks for LSP lsp_types to the given converter."""
resolve_forward_references()
Expand All @@ -39,10 +39,8 @@ def get_converter(
def _register_required_structure_hooks(
converter: cattrs.Converter,
) -> cattrs.Converter:
def _lsp_object_hook(
object_: Any, type_: type
) -> Union[lsp_types.LSPObject, lsp_types.LSPArray, str, int, float, bool, None]:
if not object_:
def _lsp_object_hook(object_: Any, type_: type) -> Any:
if object_ is None:
return object_
else:
for type_ in [str, bool, int, float, list]:
Expand Down Expand Up @@ -165,7 +163,7 @@ def _notebook_filter_hook(


def _register_custom_property_hooks(converter: cattrs.Converter) -> cattrs.Converter:
def _to_camel_case(name: str):
def _to_camel_case(name: str) -> str:
# TODO: when min Python becomes >= 3.9, then update this to:
# `return name.removesuffix("_")`.
new_name = name[:-1] if name.endswith("_") else name
Expand Down
17 changes: 14 additions & 3 deletions lsprotocol/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@
# Licensed under the MIT License.


from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
import attrs

INTEGER_MIN_VALUE = -(2**31)
INTEGER_MAX_VALUE = 2**31 - 1


def integer_validator(instance: Any, attribute: Any, value: Any) -> bool:
def integer_validator(
instance: Any,
attribute: "attrs.Attribute[int]",
value: Any,
) -> bool:
"""Validates that integer value belongs in the range expected by LSP."""
if not isinstance(value, int) or not (
INTEGER_MIN_VALUE <= value <= INTEGER_MAX_VALUE
Expand All @@ -24,7 +31,11 @@ def integer_validator(instance: Any, attribute: Any, value: Any) -> bool:
UINTEGER_MAX_VALUE = 2**31 - 1


def uinteger_validator(instance: Any, attribute: Any, value: Any) -> bool:
def uinteger_validator(
instance: Any,
attribute: "attrs.Attribute[int]",
value: Any,
) -> bool:
"""Validates that unsigned integer value belongs in the range expected by
LSP."""
if not isinstance(value, int) or not (
Expand Down
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,13 @@ exclude = [
"SUPPORT.md",
]


[tool.mypy]
files = "lsprotocol"
show_error_codes = true
strict = true
enable_error_code = [
"ignore-without-code",
"redundant-expr",
"truthy-bool",
]
enable_recursive_aliases = true