Skip to content

Commit 3dc6792

Browse files
niknikitabelonogov
authored andcommitted
Add tests
1 parent 74f15c9 commit 3dc6792

File tree

3 files changed

+152
-25
lines changed

3 files changed

+152
-25
lines changed
Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
22
import types
33
import sys
4+
import functools
5+
from typing import Type, Dict, Any, Tuple, Generator
46
from pathlib import Path
57
from tempfile import TemporaryDirectory
68
from datamodel_code_generator import InputFileType, generate, DataModelType, LiteralType
@@ -9,8 +11,25 @@
911
from contextlib import contextmanager
1012

1113

14+
@functools.lru_cache(maxsize=128)
15+
def _generate_model_code(json_schema_str: str, class_name: str = 'MyModel') -> str:
16+
with TemporaryDirectory() as temp_dir:
17+
temp_file = Path(temp_dir) / "schema.py"
18+
19+
generate(
20+
json_schema_str,
21+
input_file_type=InputFileType.JsonSchema,
22+
input_filename="schema.json",
23+
output=temp_file,
24+
output_model_type=DataModelType.PydanticV2BaseModel,
25+
enum_field_as_literal=LiteralType.All,
26+
class_name=class_name
27+
)
28+
29+
return temp_file.read_text()
30+
1231
@contextmanager
13-
def json_schema_to_pydantic(json_schema: dict, class_name: str = 'MyModel') -> type[BaseModel]:
32+
def json_schema_to_pydantic(json_schema: dict, class_name: str = 'MyModel') -> Generator[Type[BaseModel], None, None]:
1433
"""
1534
Convert a JSON schema to a Pydantic model and provide it as a context manager.
1635
@@ -36,31 +55,26 @@ def json_schema_to_pydantic(json_schema: dict, class_name: str = 'MyModel') -> t
3655
print(instance.model_dump())
3756
```
3857
"""
58+
# Convert the JSON schema dictionary to a JSON string
3959
json_schema_str = json.dumps(json_schema)
40-
41-
with TemporaryDirectory() as temp_dir:
42-
temp_file = Path(temp_dir) / "schema.py"
43-
44-
generate(
45-
json_schema_str,
46-
input_file_type=InputFileType.JsonSchema,
47-
input_filename="schema.json",
48-
output=temp_file,
49-
output_model_type=DataModelType.PydanticV2BaseModel,
50-
enum_field_as_literal=LiteralType.All,
51-
class_name=class_name
52-
)
53-
54-
model_code = temp_file.read_text()
55-
56-
mod = types.ModuleType('dynamic_module')
60+
61+
# Generate Pydantic model code from the JSON schema string
62+
model_code: str = _generate_model_code(json_schema_str, class_name)
63+
64+
# Create a unique module name using the id of the JSON schema string
65+
module_name = f'dynamic_module_{id(json_schema_str)}'
66+
67+
# Create a new module object with the unique name and execute the generated model code in the context of the new module
68+
mod = types.ModuleType(module_name)
5769
exec(model_code, mod.__dict__)
58-
5970
model_class = getattr(mod, class_name)
6071

6172
try:
62-
sys.modules['dynamic_module'] = mod
73+
# Add the new module to sys.modules to make it importable
74+
# This is necessary to avoid Pydantic errors related to undefined models
75+
sys.modules[module_name] = mod
6376
yield model_class
6477
finally:
65-
if 'dynamic_module' in sys.modules:
66-
del sys.modules['dynamic_module']
78+
if module_name in sys.modules:
79+
del sys.modules[module_name]
80+

src/label_studio_sdk/label_interface/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def create_regions(self, data: Dict[str, Union[str, Dict, List[str], List[Dict]]
321321
# 2. we should be less open regarding the payload type and defining the strict typing elsewhere,
322322
# but likely that requires rewriting of how ControlTag.label() is working now
323323
if isinstance(payload, str):
324-
payload = {'label': payload}
324+
payload = {'label': payload, 'text': [payload]}
325325
elif isinstance(payload, list):
326326
if len(payload) > 0:
327327
if isinstance(payload[0], str):

tests/custom/test_interface/test_json_schema.py

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import json
23
from datetime import datetime, timezone
34
from label_studio_sdk.label_interface.interface import LabelInterface
45
from label_studio_sdk.label_interface.control_tags import ControlTag
@@ -200,7 +201,119 @@ def test_to_json_schema(config, expected_json_schema, input_arg, expected_result
200201
json_schema = interface.to_json_schema()
201202
assert json_schema == expected_json_schema
202203

203-
# convert JSON Schema to Pydantic
204204
with json_schema_to_pydantic(json_schema) as ResponseModel:
205205
instance = ResponseModel(**input_arg)
206-
assert instance.model_dump() == expected_result
206+
assert instance.model_dump() == expected_result
207+
208+
209+
210+
def process_json_schema(json_schema, input_arg, queue):
211+
with json_schema_to_pydantic(json_schema) as ResponseModel:
212+
instance = ResponseModel(**input_arg)
213+
queue.put(instance.model_dump())
214+
215+
def test_concurrent_json_schema_to_pydantic():
216+
import multiprocessing
217+
json_schema = {
218+
"type": "object",
219+
"properties": {
220+
"sentiment": {
221+
"type": "string",
222+
"description": "Choices for doc",
223+
"enum": ["Positive", "Negative", "Neutral"],
224+
}
225+
},
226+
"required": ["sentiment"]
227+
}
228+
input_arg1 = {"sentiment": "Positive"}
229+
input_arg2 = {"sentiment": "Negative"}
230+
231+
queue = multiprocessing.Queue()
232+
233+
p1 = multiprocessing.Process(target=process_json_schema, args=(json_schema, input_arg1, queue))
234+
p2 = multiprocessing.Process(target=process_json_schema, args=(json_schema, input_arg2, queue))
235+
236+
p1.start()
237+
p2.start()
238+
239+
p1.join()
240+
p2.join()
241+
242+
results = [queue.get() for _ in range(2)]
243+
244+
assert {"sentiment": "Positive"} in results
245+
assert {"sentiment": "Negative"} in results
246+
assert len(results) == 2
247+
248+
249+
def process_json_schema_threaded(json_schema, input_arg, results, index):
250+
with json_schema_to_pydantic(json_schema) as ResponseModel:
251+
instance = ResponseModel(**input_arg)
252+
results[index] = instance.model_dump()
253+
254+
def test_concurrent_json_schema_to_pydantic_threaded():
255+
import threading
256+
import time
257+
258+
json_schema = {
259+
"type": "object",
260+
"properties": {
261+
"sentiment": {
262+
"type": "string",
263+
"description": "Choices for doc",
264+
"enum": ["Positive", "Negative", "Neutral"],
265+
}
266+
},
267+
"required": ["sentiment"]
268+
}
269+
input_args = [
270+
{"sentiment": "Positive"},
271+
{"sentiment": "Negative"},
272+
{"sentiment": "Neutral"},
273+
{"sentiment": "Positive"}
274+
]
275+
276+
results = [None] * len(input_args)
277+
threads = []
278+
279+
# Create and start threads
280+
for i, input_arg in enumerate(input_args):
281+
thread = threading.Thread(target=process_json_schema_threaded, args=(json_schema, input_arg, results, i))
282+
threads.append(thread)
283+
thread.start()
284+
285+
# Wait for all threads to complete
286+
for thread in threads:
287+
thread.join()
288+
289+
# Verify results
290+
assert {"sentiment": "Positive"} in results
291+
assert {"sentiment": "Negative"} in results
292+
assert {"sentiment": "Neutral"} in results
293+
assert results.count({"sentiment": "Positive"}) == 2
294+
assert len(results) == 4
295+
assert None not in results
296+
297+
# Verify thread safety by running multiple times
298+
for _ in range(10):
299+
results = [None] * len(input_args)
300+
threads = []
301+
302+
start_time = time.time()
303+
for i, input_arg in enumerate(input_args):
304+
thread = threading.Thread(target=process_json_schema_threaded, args=(json_schema, input_arg, results, i))
305+
threads.append(thread)
306+
thread.start()
307+
308+
for thread in threads:
309+
thread.join()
310+
311+
end_time = time.time()
312+
313+
assert set(result["sentiment"] for result in results) == set(["Positive", "Negative", "Neutral"])
314+
assert results.count({"sentiment": "Positive"}) == 2
315+
assert len(results) == 4
316+
assert None not in results
317+
318+
# Check if execution time is reasonable (adjust as needed)
319+
assert end_time - start_time < 1.0, f"Execution took too long: {end_time - start_time} seconds"

0 commit comments

Comments
 (0)