Skip to content

Commit b545c81

Browse files
authored
Add comparison methods to few types. (#70)
Closes #60
1 parent ec1fe9c commit b545c81

8 files changed

Lines changed: 221 additions & 11 deletions

File tree

generator/requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ importlib-resources==5.9.0 \
1818
# via
1919
# -r ./generator/requirements.in
2020
# jsonschema
21-
jsonschema==4.12.1 \
22-
--hash=sha256:05f975aee3f1244a1ea0e018e8ad2672f6ca5fd1a28bc46ffc7d4b3e9896cac4 \
23-
--hash=sha256:c7dd96a88c4ea60bdc8478589ee2d4ea5d73ab235e24d17641ad733dde4e3eb1
21+
jsonschema==4.13.0 \
22+
--hash=sha256:3776512df4f53f74e6e28fe35717b5b223c1756875486984a31bc9165e7fc920 \
23+
--hash=sha256:870a61bb45050b81103faf6a4be00a0a906e06636ffcf0b84f5a2e51faf901ff
2424
# via -r ./generator/requirements.in
2525
pkgutil-resolve-name==1.3.10 \
2626
--hash=sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174 \

generator/utils.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def _reset(self):
139139
self._types: OrderedDict[str, List[str]] = collections.OrderedDict()
140140
self._imports: List[str] = [
141141
"import enum",
142+
"import functools",
142143
"from typing import Any, Dict, List, Optional, Tuple, Union",
143144
"import attrs",
144145
"from . import validators",
@@ -292,6 +293,41 @@ def _has_type(self, type_name: str) -> bool:
292293
type_name = type_name[1:-1]
293294
return type_name in self._types
294295

296+
def _get_additional_methods(self, class_name: str) -> List[str]:
297+
indent = " " * 4
298+
if class_name == "Position":
299+
return [
300+
"def __eq__(self, o: 'Position') -> Union[bool, 'NotImplemented']:",
301+
f"{indent}if not isinstance(o, Position):",
302+
f"{indent}{indent}return NotImplemented",
303+
f"{indent}return (self.line, self.character) == (o.line, o.character)",
304+
"def __gt__(self, o: 'Position') -> Union[bool, 'NotImplemented']:",
305+
f"{indent}if not isinstance(o, Position):",
306+
f"{indent}{indent}return NotImplemented",
307+
f"{indent}return (self.line, self.character) > (o.line, o.character)",
308+
"def __repr__(self) -> str:",
309+
f"{indent}" + "return f'{self.line}:{self.character}'",
310+
]
311+
if class_name == "Range":
312+
return [
313+
"def __eq__(self, o: 'Range') -> Union[bool, 'NotImplemented']:",
314+
f"{indent}if not isinstance(o, Range):",
315+
f"{indent}{indent}return NotImplemented",
316+
f"{indent}return (self.start == o.start) and (self.end == o.end)",
317+
"def __repr__(self) -> str:",
318+
f"{indent}" + "return f'{self.start!r}-{self.end!r}'",
319+
]
320+
if class_name == "Location":
321+
return [
322+
"def __eq__(self, o:'Location') -> Union[bool, 'NotImplemented']:",
323+
f"{indent}if not isinstance(o, Location):",
324+
f"{indent}{indent}return NotImplemented",
325+
f"{indent}return (self.uri == o.uri) and (self.range == o.range)",
326+
"def __repr__(self) -> str:",
327+
f"{indent}" + "return f'{self.uri}:{self.range!r}'",
328+
]
329+
return None
330+
295331
def _add_type_code(self, type_name: str, code: List[str]) -> None:
296332
if not self._has_type(type_name):
297333
self._types[type_name] = code
@@ -529,6 +565,7 @@ def _add_structure(
529565

530566
class_lines = [
531567
"" if class_name == "LSPObject" else "@attrs.define",
568+
"@functools.total_ordering" if class_name == "Position" else "",
532569
f"class {class_name}:",
533570
f'{indent}"""{doc}"""' if struct_def.documentation else "",
534571
f"{indent}# Since: {_sanitize_comment(struct_def.since)}"
@@ -549,10 +586,15 @@ def _add_structure(
549586
properties += copy.deepcopy(d.properties)
550587

551588
code_lines += self._generate_properties(class_name, properties, indent)
589+
methods = self._get_additional_methods(class_name)
590+
552591
# If the class has no properties then add `pass`
553-
if len(properties) == 0:
592+
if len(properties) == 0 and not methods:
554593
code_lines += [f"{indent}pass"]
555594

595+
if methods:
596+
code_lines += [f"{indent}{l}" for l in methods]
597+
556598
# Detect if the class has properties that might be keywords.
557599
self._add_type_code(class_name, code_lines)
558600

@@ -764,9 +806,9 @@ def _add_lsp_method_type(self, lsp_model: model.LSPModel) -> None:
764806
"@enum.unique",
765807
"class MessageDirection(enum.Enum):",
766808
]
767-
code_lines += [
768-
f"{indent}{_capitalized_item_name(m)} = '{m}'" for m in directions
769-
]
809+
code_lines += sorted(
810+
[f"{indent}{_capitalized_item_name(m)} = '{m}'" for m in directions]
811+
)
770812
self._add_type_code("MessageDirection", code_lines)
771813

772814
def _generate_code(self, lsp_model: model.LSPModel) -> None:

lsprotocol/types.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# 3. Run command: `python -m nox --session build_lsp`
99

1010
import enum
11+
import functools
1112
from typing import Any, Dict, List, Optional, Tuple, Union
1213

1314
import attrs
@@ -996,6 +997,14 @@ class Location:
996997

997998
range: "Range" = attrs.field()
998999

1000+
def __eq__(self, o: "Location") -> Union[bool, "NotImplemented"]:
1001+
if not isinstance(o, Location):
1002+
return NotImplemented
1003+
return (self.uri == o.uri) and (self.range == o.range)
1004+
1005+
def __repr__(self) -> str:
1006+
return f"{self.uri}:{self.range!r}"
1007+
9991008

10001009
@attrs.define
10011010
class TextDocumentRegistrationOptions:
@@ -4772,6 +4781,14 @@ class Range:
47724781
end: "Position" = attrs.field()
47734782
"""The range's end position."""
47744783

4784+
def __eq__(self, o: "Range") -> Union[bool, "NotImplemented"]:
4785+
if not isinstance(o, Range):
4786+
return NotImplemented
4787+
return (self.start == o.start) and (self.end == o.end)
4788+
4789+
def __repr__(self) -> str:
4790+
return f"{self.start!r}-{self.end!r}"
4791+
47754792

47764793
@attrs.define
47774794
class WorkspaceFoldersChangeEvent:
@@ -4826,6 +4843,7 @@ class Color:
48264843

48274844

48284845
@attrs.define
4846+
@functools.total_ordering
48294847
class Position:
48304848
"""Position in a text document expressed as zero-based line and character
48314849
offset. Prior to 3.17 the offsets were always based on a UTF-16 string
@@ -4876,6 +4894,19 @@ class Position:
48764894
If the character value is greater than the line length it defaults back to the
48774895
line length."""
48784896

4897+
def __eq__(self, o: "Position") -> Union[bool, "NotImplemented"]:
4898+
if not isinstance(o, Position):
4899+
return NotImplemented
4900+
return (self.line, self.character) == (o.line, o.character)
4901+
4902+
def __gt__(self, o: "Position") -> Union[bool, "NotImplemented"]:
4903+
if not isinstance(o, Position):
4904+
return NotImplemented
4905+
return (self.line, self.character) > (o.line, o.character)
4906+
4907+
def __repr__(self) -> str:
4908+
return f"{self.line}:{self.character}"
4909+
48794910

48804911
@attrs.define
48814912
class SemanticTokensEdit:
@@ -10473,9 +10504,9 @@ class ProgressNotification:
1047310504

1047410505
@enum.unique
1047510506
class MessageDirection(enum.Enum):
10476-
ServerToClient = "serverToClient"
1047710507
Both = "both"
1047810508
ClientToServer = "clientToServer"
10509+
ServerToClient = "serverToClient"
1047910510

1048010511

1048110512
CALL_HIERARCHY_INCOMING_CALLS = "callHierarchy/incomingCalls"

noxfile.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ def _generate_model(session: nox.Session):
3434
session.run("pip", "list")
3535

3636
session.run("python", "-m", "generator", "--output", "./lsprotocol/types.py")
37+
session.run("isort", "--profile", "black", "./lsprotocol/types.py")
38+
session.run("black", "./lsprotocol/types.py")
39+
session.run("docformatter", "--in-place", "./lsprotocol/types.py")
40+
3741
session.run("isort", "--profile", "black", "./lsprotocol")
3842
session.run("black", "./lsprotocol")
3943
session.run("docformatter", "--in-place", "--recursive", "./lsprotocol")

tests/requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ iniconfig==1.1.1 \
4242
--hash=sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3 \
4343
--hash=sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32
4444
# via pytest
45-
jsonschema==4.12.1 \
46-
--hash=sha256:05f975aee3f1244a1ea0e018e8ad2672f6ca5fd1a28bc46ffc7d4b3e9896cac4 \
47-
--hash=sha256:c7dd96a88c4ea60bdc8478589ee2d4ea5d73ab235e24d17641ad733dde4e3eb1
45+
jsonschema==4.13.0 \
46+
--hash=sha256:3776512df4f53f74e6e28fe35717b5b223c1756875486984a31bc9165e7fc920 \
47+
--hash=sha256:870a61bb45050b81103faf6a4be00a0a906e06636ffcf0b84f5a2e51faf901ff
4848
# via -r ./tests/requirements.in
4949
packaging==21.3 \
5050
--hash=sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb \

tests/test_location.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import hamcrest
5+
import pytest
6+
7+
from lsprotocol import types as lsp
8+
9+
10+
@pytest.mark.parametrize(
11+
("a", "b", "expected"),
12+
[
13+
(
14+
lsp.Location(
15+
"some_path", lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56))
16+
),
17+
lsp.Location(
18+
"some_path", lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56))
19+
),
20+
True,
21+
),
22+
(
23+
lsp.Location(
24+
"some_path", lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56))
25+
),
26+
lsp.Location(
27+
"some_path2", lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56))
28+
),
29+
False,
30+
),
31+
(
32+
lsp.Location(
33+
"some_path", lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56))
34+
),
35+
lsp.Location(
36+
"some_path", lsp.Range(lsp.Position(1, 23), lsp.Position(8, 91))
37+
),
38+
False,
39+
),
40+
],
41+
)
42+
def test_location_equality(a, b, expected):
43+
hamcrest.assert_that(a == b, hamcrest.is_(expected))
44+
45+
46+
def test_location_repr():
47+
a = lsp.Location("some_path", lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56)))
48+
hamcrest.assert_that(f"{a!r}", hamcrest.is_("some_path:1:23-4:56"))

tests/test_position.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import hamcrest
5+
import pytest
6+
7+
from lsprotocol import types as lsp
8+
9+
10+
@pytest.mark.parametrize(
11+
("a", "b", "comp", "expected"),
12+
[
13+
(lsp.Position(1, 10), lsp.Position(1, 10), "==", True),
14+
(lsp.Position(1, 10), lsp.Position(1, 11), "==", False),
15+
(lsp.Position(1, 10), lsp.Position(1, 11), "!=", True),
16+
(lsp.Position(1, 10), lsp.Position(2, 20), "!=", True),
17+
(lsp.Position(2, 10), lsp.Position(1, 10), ">", True),
18+
(lsp.Position(2, 10), lsp.Position(1, 10), ">=", True),
19+
(lsp.Position(1, 11), lsp.Position(1, 10), ">", True),
20+
(lsp.Position(1, 11), lsp.Position(1, 10), ">=", True),
21+
(lsp.Position(1, 10), lsp.Position(1, 10), ">=", True),
22+
(lsp.Position(1, 10), lsp.Position(2, 10), "<", True),
23+
(lsp.Position(1, 10), lsp.Position(2, 10), "<=", True),
24+
(lsp.Position(1, 10), lsp.Position(1, 10), "<=", True),
25+
(lsp.Position(1, 10), lsp.Position(1, 11), "<", True),
26+
(lsp.Position(1, 10), lsp.Position(1, 11), "<=", True),
27+
],
28+
)
29+
def test_position_comparison(
30+
a: lsp.Position, b: lsp.Position, comp: str, expected: bool
31+
):
32+
if comp == "==":
33+
result = a == b
34+
elif comp == "!=":
35+
result = a != b
36+
elif comp == "<":
37+
result = a < b
38+
elif comp == "<=":
39+
result = a <= b
40+
elif comp == ">":
41+
result = a > b
42+
elif comp == ">=":
43+
result = a >= b
44+
hamcrest.assert_that(result, hamcrest.is_(expected))
45+
46+
47+
def test_position_repr():
48+
p = lsp.Position(1, 23)
49+
hamcrest.assert_that(f"{p!r}", hamcrest.is_("1:23"))

tests/test_range.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import hamcrest
5+
import pytest
6+
7+
from lsprotocol import types as lsp
8+
9+
10+
@pytest.mark.parametrize(
11+
("a", "b", "expected"),
12+
[
13+
(
14+
lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56)),
15+
lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56)),
16+
True,
17+
),
18+
(
19+
lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56)),
20+
lsp.Range(lsp.Position(1, 23), lsp.Position(4, 57)),
21+
False,
22+
),
23+
(
24+
lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56)),
25+
lsp.Range(lsp.Position(1, 23), lsp.Position(7, 56)),
26+
False,
27+
),
28+
],
29+
)
30+
def test_range_equality(a, b, expected):
31+
hamcrest.assert_that(a == b, hamcrest.is_(expected))
32+
33+
34+
def test_range_repr():
35+
a = lsp.Range(lsp.Position(1, 23), lsp.Position(4, 56))
36+
hamcrest.assert_that(f"{a!r}", hamcrest.is_("1:23-4:56"))

0 commit comments

Comments
 (0)