-
Notifications
You must be signed in to change notification settings - Fork 231
devstral tool parser for tool calling #3851
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| //***************************************************************************** | ||
| // Copyright 2025 Intel Corporation | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
| //***************************************************************************** | ||
|
|
||
| #include <memory> | ||
| #include <string> | ||
| #include <utility> | ||
| #include <openvino/genai/generation_config.hpp> | ||
|
|
||
| #include "generation_config_builder.hpp" | ||
|
|
||
| namespace ovms { | ||
|
|
||
| void DevstralGenerationConfigBuilder::parseConfigFromRequest(const OpenAIChatCompletionsRequest& request) { | ||
| // Call the base class method to fill in common configuration | ||
| BaseGenerationConfigBuilder::parseConfigFromRequest(request); | ||
|
|
||
| // For now the only specific part is related to tools, so if there are no tools provided in the request | ||
| // we can exit early | ||
| if (request.toolNameSchemaMap.empty()) { | ||
| return; | ||
| } | ||
|
|
||
| if (enableToolGuidedGeneration || request.toolChoice == "required") { | ||
| // Set tool guided generation config specific to Devstral model | ||
| auto triggeredTags = std::make_shared<ov::genai::StructuredOutputConfig::TriggeredTags>(); | ||
| triggeredTags->triggers.push_back("[TOOL_CALLS]"); | ||
|
|
||
| for (const auto& [toolName, toolSchemaWrapper] : request.toolNameSchemaMap) { | ||
| const auto& toolSchema = toolSchemaWrapper.stringRepr; | ||
| ov::genai::StructuredOutputConfig::Tag tagItem; | ||
| tagItem.begin = "[TOOL_CALLS]" + toolName + "[ARGS]"; | ||
| tagItem.end = "</s>"; | ||
| tagItem.content = ov::genai::StructuredOutputConfig::JSONSchema(toolSchema); | ||
| triggeredTags->tags.push_back(tagItem); | ||
| } | ||
| if (request.toolChoice == "required") { | ||
| triggeredTags->at_least_one = true; | ||
| } | ||
| ov::genai::StructuredOutputConfig::StructuralTag structuralTag = triggeredTags; | ||
| setStructuralTagsConfig(structuralTag); | ||
| } | ||
| } | ||
|
|
||
| } // namespace ovms |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| //***************************************************************************** | ||
| // Copyright 2025 Intel Corporation | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
| //***************************************************************************** | ||
| #pragma once | ||
| #include "../base_generation_config_builder.hpp" | ||
|
|
||
| namespace ovms { | ||
|
|
||
/*
 * DevstralGenerationConfigBuilder extends BaseGenerationConfigBuilder to provide specific configuration for the Devstral model.
 * It overrides the parseConfigFromRequest method to set tool guided generation config.
 */
class DevstralGenerationConfigBuilder : public BaseGenerationConfigBuilder {
public:
    DevstralGenerationConfigBuilder() = delete;
    explicit DevstralGenerationConfigBuilder(const ov::genai::GenerationConfig& baseConfig, bool enableToolGuidedGeneration, DecodingMethod decodingMethod) :
        BaseGenerationConfigBuilder(baseConfig, enableToolGuidedGeneration, decodingMethod) {}

    // Fills the config from the request; additionally sets Devstral-specific
    // structured output settings when tools are present in the request.
    void parseConfigFromRequest(const OpenAIChatCompletionsRequest& request) override;
};
| } // namespace ovms |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,163 @@ | ||||||||||
| //***************************************************************************** | ||||||||||
| // Copyright 2025 Intel Corporation | ||||||||||
| // | ||||||||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||
| // you may not use this file except in compliance with the License. | ||||||||||
| // You may obtain a copy of the License at | ||||||||||
| // | ||||||||||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||
| // | ||||||||||
| // Unless required by applicable law or agreed to in writing, software | ||||||||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||
| // See the License for the specific language governing permissions and | ||||||||||
| // limitations under the License. | ||||||||||
| //***************************************************************************** | ||||||||||
|
|
||||||||||
| #include <openvino/genai/tokenizer.hpp> | ||||||||||
| #include <string> | ||||||||||
| #include <vector> | ||||||||||
| #include <regex> | ||||||||||
|
|
||||||||||
| #include "src/port/rapidjson_document.hpp" | ||||||||||
|
|
||||||||||
| #include "../../../logging.hpp" | ||||||||||
| #include "tool_parser.hpp" | ||||||||||
| #include "../utils.hpp" | ||||||||||
| #include "src/stringutils.hpp" | ||||||||||
|
|
||||||||||
| namespace ovms { | ||||||||||
|
|
||||||||||
| void DevstralToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) { | ||||||||||
| std::vector<std::string> tools; | ||||||||||
| // Parser will consume entire model output only if the first generated token is the beginning of tools token. | ||||||||||
| // expected format: [TOOL_CALLS]tool_name[ARGS]{"arg1": "value1", ...} | ||||||||||
| if (parsedOutput.content.empty() || generatedTokens.size() <= 0) { | ||||||||||
| SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No content to parse for tool calls"); | ||||||||||
| return; | ||||||||||
| } | ||||||||||
| size_t firstToolTokenIndex; | ||||||||||
| auto it = std::find(generatedTokens.begin(), generatedTokens.end(), this->botTokenId); | ||||||||||
| if (it != generatedTokens.end()) { | ||||||||||
| firstToolTokenIndex = std::distance(generatedTokens.begin(), it); | ||||||||||
| } else { | ||||||||||
| return; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| size_t firstArgsTokenIndex; | ||||||||||
| auto it_args = std::find(generatedTokens.begin() + firstToolTokenIndex, generatedTokens.end(), this->argsTokenId); | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please keep the naming convention
Suggested change
|
||||||||||
| if (it_args != generatedTokens.end()) { | ||||||||||
| firstArgsTokenIndex = std::distance(generatedTokens.begin(), it_args); | ||||||||||
| } else { | ||||||||||
| return; | ||||||||||
| } | ||||||||||
| if (firstToolTokenIndex > firstArgsTokenIndex) { | ||||||||||
| SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "First tool token index is greater than first args token index."); | ||||||||||
| return; | ||||||||||
| } | ||||||||||
| std::vector<int64_t> tool_name_tokens(generatedTokens.begin() + (firstToolTokenIndex + 1), generatedTokens.begin() + (firstArgsTokenIndex)); | ||||||||||
| std::vector<int64_t> arguments_tokens(generatedTokens.begin() + (firstArgsTokenIndex + 1), generatedTokens.end()); | ||||||||||
|
Comment on lines
+58
to
+59
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
|
||||||||||
| ToolCall toolCall; | ||||||||||
| std::string tool_name = tokenizer.decode(tool_name_tokens, ov::AnyMap{ov::genai::skip_special_tokens(true)}); | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| if (this->toolSchemas.find(tool_name) == this->toolSchemas.end()) { | ||||||||||
| SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool name '{}' not valid.", tool_name); | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is behavior we havent implemented in other parsers, is it really worth return early? if we return function name that is not part of the toolschemas spec, we might be able to debug it in bfcl
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not in line with current behavior in other parsers. I wouldn't do that check if it's only for this parser. Either drop it or create a task for alignment of other parsers. |
||||||||||
| return; | ||||||||||
| } | ||||||||||
| std::string arguments = tokenizer.decode(arguments_tokens, ov::AnyMap{ov::genai::skip_special_tokens(true)}); | ||||||||||
|
|
||||||||||
| toolCall.name = tool_name; | ||||||||||
| toolCall.arguments = arguments; | ||||||||||
| toolCall.id = generateRandomId(); // Generate a random ID for the tool call | ||||||||||
| parsedOutput.toolCalls.push_back(toolCall); | ||||||||||
|
|
||||||||||
| // get subset of generatedTokens starting from begin() to firstArgsTokenIndex | ||||||||||
| std::vector<int64_t> content_tokens; | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| if (firstToolTokenIndex > 0) { | ||||||||||
| content_tokens = std::vector<int64_t>(generatedTokens.begin(), generatedTokens.begin() + firstToolTokenIndex); | ||||||||||
| parsedOutput.content = tokenizer.decode(content_tokens, ov::AnyMap{ov::genai::skip_special_tokens(true)}); // Return only the content till tool call | ||||||||||
| } else { | ||||||||||
| parsedOutput.content = ""; | ||||||||||
| } | ||||||||||
| return; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| std::optional<rapidjson::Document> DevstralToolParser::sendFullDelta(ToolCall& toolCall) { | ||||||||||
| rapidjson::Document argsDelta; | ||||||||||
| argsDelta.Parse(toolCall.arguments.c_str()); | ||||||||||
| rapidjson::Document argumentsWrapper; | ||||||||||
| argumentsWrapper.SetObject(); | ||||||||||
| rapidjson::Document::AllocatorType& allocator = argumentsWrapper.GetAllocator(); | ||||||||||
| // now we need to add string toolCall.arguments to argumentsWrapper under "arguments" key | ||||||||||
| rapidjson::Value toolCallsString(rapidjson::kStringType); | ||||||||||
| toolCallsString.SetString(toolCall.arguments.c_str(), allocator); | ||||||||||
| argumentsWrapper.AddMember("arguments", toolCallsString, allocator); | ||||||||||
| auto currentDelta = wrapDelta(argumentsWrapper, this->toolCallIndex); | ||||||||||
| return currentDelta; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| std::optional<rapidjson::Document> DevstralToolParser::parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) { | ||||||||||
| /* | ||||||||||
| Devstral [TOOL_CALL]tool_name[ARGS]arguments[</s>] | ||||||||||
| It does not support parallel tool calls, so tool calls are always in sequence. | ||||||||||
|
|
||||||||||
| We have three processing states: | ||||||||||
| AWAITING_START_TAG, | ||||||||||
| AWAITING_ARGS_TAG, | ||||||||||
| PROCESSING_ARGS | ||||||||||
|
|
||||||||||
| We store the history of chunks in streamContent string. After state changes are detected, we clear the streamContent to only keep unprocessed part. | ||||||||||
| */ | ||||||||||
|
|
||||||||||
| this->streamContent += chunk; | ||||||||||
| if (this->internalState == AWAITING_START_TAG) { | ||||||||||
| size_t pos = chunk.find("[TOOL_CALLS]"); | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should look up on
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. first chunk starts with [TOOL_CALLS] and it is never split as it is one token |
||||||||||
| if (pos != std::string::npos) { | ||||||||||
| this->internalState = AWAITING_ARGS_TAG; | ||||||||||
| this->toolCallIndex++; | ||||||||||
| if (pos == 0) { | ||||||||||
| this->streamContent.clear(); | ||||||||||
| } else { | ||||||||||
| this->streamContent = this->streamContent.substr(pos + 13); // "[TOOLS_CALLS]" length is 13 | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should avoid magic numbers, this way if we change
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can look up how Adrian handles that in qwen coder |
||||||||||
| } | ||||||||||
| } else { | ||||||||||
| return std::nullopt; | ||||||||||
| } | ||||||||||
| } | ||||||||||
| if (this->internalState == AWAITING_ARGS_TAG) { | ||||||||||
| // check if [ARGS] tag is present in the chunk and update state accordingly | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. technically we check streamContent but it will be the case only if [ARGS] is added in the chunk. Otherwise it would be different state |
||||||||||
| size_t pos = this->streamContent.find("[ARGS]"); | ||||||||||
| if (pos != std::string::npos) { | ||||||||||
| this->internalState = PROCESSING_ARGS; | ||||||||||
| this->toolName = this->streamContent.substr(0, pos); | ||||||||||
| if (this->toolSchemas.find(this->toolName) == this->toolSchemas.end()) { | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As for the unary part - this check is unique to this parser and I don't think it's a good idea to have different behavior for different parsers. Either remove or create a task for alignment of other parsers. |
||||||||||
| SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool name '{}' not valid.", this->toolName); | ||||||||||
| return std::nullopt; | ||||||||||
| } | ||||||||||
| this->streamContent = this->streamContent.substr(pos + 6); // "[ARGS]" length is 6 | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Magic number |
||||||||||
| return wrapFirstDelta(this->toolName, this->toolCallIndex); | ||||||||||
| } else { | ||||||||||
| return std::nullopt; | ||||||||||
| } | ||||||||||
| } | ||||||||||
| if (finishReason != ov::genai::GenerationFinishReason::NONE) { | ||||||||||
| size_t end_pos = this->streamContent.find("</s>"); | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is this token? If it has some significant value for the parsing it should be a member of the parser class like args and tool calls token. Also:
Suggested change
|
||||||||||
| std::string arguments; | ||||||||||
| if (end_pos != std::string::npos) { | ||||||||||
| arguments = this->streamContent.substr(0, end_pos); | ||||||||||
| } else { | ||||||||||
| arguments = this->streamContent; | ||||||||||
| } | ||||||||||
| if (!arguments.empty()) { | ||||||||||
| ToolCall toolCall; | ||||||||||
| toolCall.arguments = arguments; | ||||||||||
| toolCall.name = this->toolName; | ||||||||||
| return sendFullDelta(toolCall); | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldntg we stream partial function argument chunks? if i understand correctly you send full delta at the end of generation
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already accepted such approach for qwen3 coder, so I suppose we can have it in other parsers as well unless there are specific requirements for "real" streaming. |
||||||||||
| } else { | ||||||||||
| SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No valid arguments found in streamContent."); | ||||||||||
| return std::nullopt; | ||||||||||
| } | ||||||||||
| } | ||||||||||
| return std::nullopt; | ||||||||||
| } | ||||||||||
| } // namespace ovms | ||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| //***************************************************************************** | ||
| // Copyright 2025 Intel Corporation | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
| //***************************************************************************** | ||
| #pragma once | ||
|
|
||
| #include <openvino/genai/tokenizer.hpp> | ||
| #include <optional> | ||
| #include <string> | ||
| #include <vector> | ||
|
|
||
| #include "src/port/rapidjson_document.hpp" | ||
|
|
||
| #include "src/llm/io_processing/base_output_parser.hpp" | ||
| #include "src/llm/io_processing/partial_json_builder.hpp" | ||
| #include "src/llm/apis/tool_schema_wrapper.hpp" | ||
|
|
||
| namespace ovms { | ||
// Parser for Devstral tool calling output.
// Expected model output format: [TOOL_CALLS]tool_name[ARGS]{...json arguments...}</s>
class DevstralToolParser : public BaseOutputParser {
    const int64_t argsTokenId;  // [ARGS]
    const int64_t botTokenId;   // [TOOL_CALLS] (beginning-of-tools token)

    // in streaming mode we can rely on tags in string format as tokens are not available
    const std::string streamingParsingArgsStartTag = "[ARGS]";
    const std::string streamingParsingToolCallsStartTag = "[TOOL_CALLS]";

    // Streaming parser state machine (see parseChunk).
    enum InternalState {
        AWAITING_START_TAG,
        AWAITING_ARGS_TAG,
        PROCESSING_ARGS
    };

    InternalState internalState = AWAITING_START_TAG;
    // Schemas of the tools declared in the request; used to validate parsed tool names.
    const ToolsSchemas_t& toolSchemas;
    // Index to track the current tool call being processed (-1 means no tool call has been started yet)
    int toolCallIndex = -1;
    std::string streamContent = "";  // content accumulated from stream chunks
    std::string toolName = "";       // name of the tool call currently being processed
    // Builds a single streaming delta carrying the complete arguments string.
    std::optional<rapidjson::Document> sendFullDelta(ToolCall& toolCall);

public:
    DevstralToolParser() = delete;
    // NOTE(review): assumes "[ARGS]" and "[TOOL_CALLS]" each encode to exactly one
    // special token in the Devstral tokenizer (only input_ids[0] is taken) —
    // TODO: validate that the encoded token count is == 1, see PR discussion.
    DevstralToolParser(ov::genai::Tokenizer& tokenizer, const ToolsSchemas_t& toolSchemas) :
        BaseOutputParser(tokenizer),
        argsTokenId(tokenizer.encode("[ARGS]", {{"add_special_tokens", false}}).input_ids.data<int64_t>()[0]),
        botTokenId(tokenizer.encode("[TOOL_CALLS]", {{"add_special_tokens", false}}).input_ids.data<int64_t>()[0]),
        toolSchemas(toolSchemas) {}

    void parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) override;
    std::optional<rapidjson::Document> parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) override;
    const std::vector<std::string>& getParsingStartTags() const override {
        static const std::vector<std::string> toolCallStartTags{streamingParsingToolCallsStartTag};
        return toolCallStartTags;
    }
    const std::vector<std::string>& getSpecialParsingStartTags() const override {
        static const std::vector<std::string> specialParsingStartTags{};
        return specialParsingStartTags;
    }
    // No end tag is reported here; parseChunk strips the trailing "</s>" itself when
    // flushing arguments at the end of generation.
    const std::string& getParsingEndTag() const override {
        static const std::string toolCallEndTag = "";
        return toolCallEndTag;
    }

    // The [TOOL_CALLS]/[ARGS] tags are special tokens, so streamed chunks must be
    // decoded with special tokens included for parseChunk to see them.
    bool requiresStreamingWithSpecialTokens() const override {
        return true;
    }
};
| } // namespace ovms | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does not look like this comment is true for this parser