
feat(api): (1/n) datasets api clean up #1573


Merged
merged 44 commits on Mar 17, 2025
Changes from 43 commits
Commits (44)
bc551e6
datasets api
yanxi0830 Mar 11, 2025
8592c2b
precommit
yanxi0830 Mar 11, 2025
0e8a53a
openapi
yanxi0830 Mar 11, 2025
02aa9a1
remove json_schema_type decorator
yanxi0830 Mar 11, 2025
0e47c65
update
yanxi0830 Mar 12, 2025
817331e
precommit
yanxi0830 Mar 12, 2025
0abedd0
comment
yanxi0830 Mar 12, 2025
1d80ec7
upgrade doc
yanxi0830 Mar 12, 2025
31e3409
Merge branch 'main' into pr1573
yanxi0830 Mar 12, 2025
f840018
Merge branch 'main' into pr1573
yanxi0830 Mar 12, 2025
8942071
Merge branch 'main' into pr1573
yanxi0830 Mar 13, 2025
18de4cd
comments
yanxi0830 Mar 13, 2025
a3173e8
update
yanxi0830 Mar 13, 2025
790b2d5
source
yanxi0830 Mar 13, 2025
09039ec
source
yanxi0830 Mar 13, 2025
4cc1958
huggingface obey consistency
yanxi0830 Mar 13, 2025
4f6f0f6
update doc
yanxi0830 Mar 13, 2025
772339b
update doc
yanxi0830 Mar 13, 2025
b4d118f
update doc
yanxi0830 Mar 13, 2025
0df3304
update doc
yanxi0830 Mar 13, 2025
8a6fa41
more purposes
yanxi0830 Mar 13, 2025
8b80a77
docs
yanxi0830 Mar 13, 2025
78ec3d9
Merge branch 'main' into pr1573
yanxi0830 Mar 13, 2025
89885fd
datasetio->datasets
yanxi0830 Mar 13, 2025
a609582
docs
yanxi0830 Mar 13, 2025
7606e49
feat(dataset api): (1.1/n) dataset api implementation fix pre-commit …
yanxi0830 Mar 13, 2025
0e2a13d
Merge branch 'main' into pr1573
yanxi0830 Mar 14, 2025
cba4842
Merge branch 'main' into pr1573
yanxi0830 Mar 14, 2025
c7d741d
Merge branch 'main' into pr1573
yanxi0830 Mar 15, 2025
39f4dfb
feat(api): (1.2/n) datasets.iterrorws pagination api updates (#1656)
yanxi0830 Mar 15, 2025
5cb0ad7
openapi gen + precommit fix
yanxi0830 Mar 15, 2025
72ccdc1
feat(datasets api): (1.3/n) patch OpenAPI gen for datasetio->datasets…
yanxi0830 Mar 15, 2025
2c9d624
feat(dataset api): (1.4/n) fix resolver signature mismatch (#1658)
yanxi0830 Mar 15, 2025
a568bf3
feat(dataset api): (1.5/n) fix dataset registeration (#1659)
yanxi0830 Mar 15, 2025
6f5df08
fix hf url endpoint
yanxi0830 Mar 15, 2025
28b8c1c
scoring fix
yanxi0830 Mar 16, 2025
f2d9332
pre
yanxi0830 Mar 16, 2025
a6fa3aa
feat(dataset api): (1.6/n) fix all iterrows callsites (#1660)
yanxi0830 Mar 16, 2025
5cf7779
fix integeration
yanxi0830 Mar 16, 2025
63f1525
precommit
yanxi0830 Mar 16, 2025
d9264a0
dataaset
yanxi0830 Mar 16, 2025
6a8bd19
next_index -> next_start_index
yanxi0830 Mar 17, 2025
cc48d9e
precommit
yanxi0830 Mar 17, 2025
54a4f41
doc update
yanxi0830 Mar 17, 2025
733 changes: 410 additions & 323 deletions docs/_static/llama-stack-spec.html

Large diffs are not rendered by default.

489 changes: 285 additions & 204 deletions docs/_static/llama-stack-spec.yaml

Large diffs are not rendered by default.

2,186 changes: 1,080 additions & 1,106 deletions docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb

Large diffs are not rendered by default.

352 changes: 268 additions & 84 deletions docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb

Large diffs are not rendered by default.

39 changes: 22 additions & 17 deletions docs/openapi_generator/pyopenapi/generator.py
@@ -435,7 +435,7 @@ def __init__(self, endpoint: type, options: Options) -> None:
)
self.schema_builder = SchemaBuilder(schema_generator)
self.responses = {}

# Create standard error responses
self._create_standard_error_responses()

@@ -446,7 +446,7 @@ def _create_standard_error_responses(self) -> None:
"""
# Get the Error schema
error_schema = self.schema_builder.classdef_to_ref(Error)

# Create standard error responses
self.responses["BadRequest400"] = Response(
description="The request was invalid or malformed",
@@ -457,11 +457,11 @@ def _create_standard_error_responses(self) -> None:
"status": 400,
"title": "Bad Request",
"detail": "The request was invalid or malformed",
}
},
)
}
},
)

self.responses["TooManyRequests429"] = Response(
description="The client has sent too many requests in a given amount of time",
content={
@@ -471,11 +471,11 @@ def _create_standard_error_responses(self) -> None:
"status": 429,
"title": "Too Many Requests",
"detail": "You have exceeded the rate limit. Please try again later.",
}
},
)
}
},
)

self.responses["InternalServerError500"] = Response(
description="The server encountered an unexpected error",
content={
@@ -485,11 +485,11 @@ def _create_standard_error_responses(self) -> None:
"status": 500,
"title": "Internal Server Error",
"detail": "An unexpected error occurred. Our team has been notified.",
}
},
)
}
},
)

# Add a default error response for any unhandled error cases
self.responses["DefaultError"] = Response(
description="An unexpected error occurred",
@@ -500,9 +500,9 @@ def _create_standard_error_responses(self) -> None:
"status": 0,
"title": "Error",
"detail": "An unexpected error occurred",
}
},
)
}
},
)
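For illustration, a minimal client-side sketch of consuming one of these standard error payloads. Only the status/title/detail fields come from the Error schema above; the URL and the use of `requests` are assumptions.

```python
# Sketch: read the standard error payload (status/title/detail) from a failed request.
# The endpoint URL below is illustrative, not a documented route.
import requests

resp = requests.get("http://localhost:8321/v1/datasets/iterrows/my-dataset-id")
if resp.status_code >= 400:
    try:
        err = resp.json()
    except ValueError:
        err = {"status": resp.status_code, "title": "Error", "detail": resp.text}
    print(f"{err['status']} {err['title']}: {err['detail']}")
```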

def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
@@ -547,11 +547,14 @@ def _build_operation(self, op: EndpointOperation) -> Operation:
"SyntheticDataGeneration",
"PostTraining",
"BatchInference",
"Files",
]:
op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)"
print(op.defining_class.__name__)

# TODO (xiyan): temporary fix for datasetio inner impl + datasets api
# if op.defining_class.__name__ in ["DatasetIO"]:
# op.defining_class.__name__ = "Datasets"

doc_string = parse_type(op.func_ref)
doc_params = dict(
(param.name, param.description) for param in doc_string.params.values()
@@ -598,7 +601,9 @@ def _build_operation(self, op: EndpointOperation) -> Operation:

# data passed in request body as raw bytes cannot have request parameters
if raw_bytes_request_body and op.request_params:
raise ValueError("Cannot have both raw bytes request body and request parameters")
raise ValueError(
"Cannot have both raw bytes request body and request parameters"
)

# data passed in request body as raw bytes
if raw_bytes_request_body:
@@ -719,7 +724,7 @@ def _build_operation(self, op: EndpointOperation) -> Operation:
responses.update(response_builder.build_response(response_options))

assert len(responses.keys()) > 0, f"No responses found for {op.name}"

# Add standard error response references
if self.options.include_standard_error_responses:
if "400" not in responses:
Expand All @@ -730,7 +735,7 @@ def _build_operation(self, op: EndpointOperation) -> Operation:
responses["500"] = ResponseRef("InternalServerError500")
if "default" not in responses:
responses["default"] = ResponseRef("DefaultError")

if op.event_type is not None:
builder = ContentBuilder(self.schema_builder)
callbacks = {
2 changes: 1 addition & 1 deletion docs/source/distributions/ondevice_distro/android_sdk.md
@@ -58,7 +58,7 @@ Breaking down the demo app, this section will show the core pieces that are used
### Setup Remote Inferencing
Start a Llama Stack server on localhost. Here is an example of how you can do this using the fireworks.ai distribution:
```
conda create -n stack-fireworks python=3.10
conda create -n stack-fireworks python=3.10
conda activate stack-fireworks
pip install --no-cache llama-stack==0.1.4
llama stack build --template fireworks --image-type conda
24 changes: 9 additions & 15 deletions docs/source/references/evals_reference/index.md
@@ -114,23 +114,17 @@ pprint(response)
simpleqa_dataset_id = "huggingface::simpleqa"

_ = client.datasets.register(
dataset_id=simpleqa_dataset_id,
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
metadata={
"path": "llamastack/simpleqa",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"},
purpose="eval/messages-answer",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id=simpleqa_dataset_id,
)

eval_rows = client.datasetio.get_rows_paginated(
eval_rows = client.datasets.iterrows(
dataset_id=simpleqa_dataset_id,
rows_in_page=5,
limit=5,
)
```
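For reference, a minimal sketch of paging through the full dataset with the new API. It reuses `client` and `simpleqa_dataset_id` from the snippet above and assumes the response exposes `data` and `next_start_index`, as in the updated `IterrowsResponse` schema further down in this PR.

```python
# Sketch: iterate over all rows, 100 at a time, using index-based pagination.
all_rows = []
start_index = None
while True:
    page = client.datasets.iterrows(
        dataset_id=simpleqa_dataset_id,
        start_index=start_index,
        limit=100,
    )
    all_rows.extend(page.data)
    if page.next_start_index is None:  # no more rows to fetch
        break
    start_index = page.next_start_index
```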

@@ -143,7 +137,7 @@ client.benchmarks.register(

response = client.eval.evaluate_rows(
benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
input_rows=eval_rows.data,
scoring_functions=["llm-as-judge::405b-simpleqa"],
benchmark_config={
"eval_candidate": {
@@ -191,7 +185,7 @@ agent_config = {

response = client.eval.evaluate_rows(
benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
input_rows=eval_rows.data,
scoring_functions=["llm-as-judge::405b-simpleqa"],
benchmark_config={
"eval_candidate": {
34 changes: 15 additions & 19 deletions llama_stack/apis/datasetio/datasetio.py
@@ -13,19 +13,16 @@


@json_schema_type
class PaginatedRowsResult(BaseModel):
class IterrowsResponse(BaseModel):
"""
A paginated list of rows from a dataset.

:param rows: The rows in the current page.
:param total_count: The total number of rows in the dataset.
:param next_page_token: The token to get the next page of rows.
:param data: The rows in the current page.
:param next_start_index: Index into dataset for the first row in the next page. None if there are no more rows.
"""

# the rows obey the DatasetSchema for the given dataset
rows: List[Dict[str, Any]]
total_count: int
next_page_token: Optional[str] = None
data: List[Dict[str, Any]]
next_start_index: Optional[int] = None


class DatasetStore(Protocol):
@@ -37,22 +34,21 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore

@webmethod(route="/datasetio/rows", method="GET")
async def get_rows_paginated(
# TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
Contributor: nit: this is not flaky (which means it sometimes works vs. sometimes not) -- I think maybe you just mean "wonkiness" or "sadness"

Contributor (Author): yeah, it actually is the case that this sometimes works and sometimes does not if I set the route to /datasets. I suspect it may be due to the way we do topological sort in our resolver.

@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
"""Get a paginated list of rows from a dataset.
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
"""Get a paginated list of rows from a dataset. Uses cursor-based pagination.
Contributor: I don't think this is using cursor-based pagination because this is just an index we are returning? Maybe just avoid saying this is cursor-based?

Contributor (Author): I'm basing the API off of this cursor-based pagination spec.

:param dataset_id: The ID of the dataset to get the rows from.
:param rows_in_page: The number of rows to get per page.
:param page_token: The token to get the next page of rows.
:param filter_condition: (Optional) A condition to filter the rows by.
:param start_index: Index into dataset for the first row to get. Get all rows if None.
:param limit: The number of rows to get.
"""
...

@webmethod(route="/datasetio/rows", method="POST")
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
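For context, a minimal in-memory sketch of the pagination contract this protocol describes (an index-based `next_start_index` rather than an opaque token). This is illustrative only, not the actual provider implementation; it assumes the `IterrowsResponse` model defined earlier in this file.

```python
# Sketch: in-memory DatasetIO-style provider honoring the iterrows/append_rows contract.
from typing import Any, Dict, List, Optional


class InMemoryDatasetIO:
    def __init__(self) -> None:
        self._rows: Dict[str, List[Dict[str, Any]]] = {}

    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        self._rows.setdefault(dataset_id, []).extend(rows)

    async def iterrows(
        self,
        dataset_id: str,
        start_index: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> IterrowsResponse:
        rows = self._rows.get(dataset_id, [])
        start = start_index or 0
        end = len(rows) if limit is None else min(start + limit, len(rows))
        return IterrowsResponse(
            data=rows[start:end],
            # None once the last row has been returned, so callers know to stop
            next_start_index=end if end < len(rows) else None,
        )
```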