From 74110bcd547cd9bcaf4380d5ca5f26375c23f443 Mon Sep 17 00:00:00 2001
From: Ajinkya Tejankar
Date: Fri, 10 Jan 2025 00:22:06 +0530
Subject: [PATCH 01/15] fp8 dynamic activation scaling (disable static scaling)

---
 server/lorax_server/layers/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/lorax_server/layers/fp8.py b/server/lorax_server/layers/fp8.py
index f03d2974a..435ccecd9 100644
--- a/server/lorax_server/layers/fp8.py
+++ b/server/lorax_server/layers/fp8.py
@@ -43,7 +43,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input=input,
             qweight=self.qweight,
             weight_scale=self.weight_scale,
-            input_scale=self.input_scale,
+            input_scale=None,
             qbias=self.qbias,
         )

From b38bfb0c306a7e08946321c9339b90a570540d43 Mon Sep 17 00:00:00 2001
From: Ajinkya Tejankar
Date: Sat, 11 Jan 2025 00:45:53 +0530
Subject: [PATCH 02/15] allow channelwise scale factors

---
 server/lorax_server/layers/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/lorax_server/layers/fp8.py b/server/lorax_server/layers/fp8.py
index 435ccecd9..1641eb42a 100644
--- a/server/lorax_server/layers/fp8.py
+++ b/server/lorax_server/layers/fp8.py
@@ -14,7 +14,7 @@ def apply_fp8_linear(
     input_scale_ub: Optional[torch.Tensor] = None,
     qbias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, scale_ub=input_scale_ub, use_per_token_if_dynamic=False)
+    qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, scale_ub=input_scale_ub, use_per_token_if_dynamic=True)

     output = ops.cutlass_scaled_mm(
         qinput, qweight, out_dtype=input.dtype, scale_a=x_scale, scale_b=weight_scale, bias=qbias

From 0882883f8bdb00b8cd90c2e639382827196d907a Mon Sep 17 00:00:00 2001
From: Magdy Saleh <17618143+magdyksaleh@users.noreply.github.com>
Date: Fri, 17 Jan 2025 22:52:19 -0500
Subject: [PATCH 03/15] fix block size in health check (#742)

---
 router/src/health.rs | 35 +++++++++++++++++++++++++----------
 rust-toolchain.toml  |  4 ++--
 2 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/router/src/health.rs b/router/src/health.rs
index 5ca8e8de8..cfda0b85c 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -1,6 +1,6 @@
 use lorax_client::{
-    Batch, NextTokenChooserParameters, Request, ShardInfo, ShardedClient,
-    StoppingCriteriaParameters,
+    input_chunk, Batch, InputChunk, NextTokenChooserParameters, Request, ShardInfo, ShardedClient,
+    StoppingCriteriaParameters, TokenizedInputs,
 };
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
@@ -40,7 +45,12 @@ impl Health {
         let generation_liveness_request = Request {
             id: LIVENESS_ID,
             inputs: "liveness".to_string(),
-            tokenized_inputs: None,
+            tokenized_inputs: Some(TokenizedInputs {
+                ids: vec![75],
+                input_chunks: vec![InputChunk {
+                    chunk: Some(input_chunk::Chunk::Text("liveness".to_string())),
+                }],
+            }),
             truncate: 10,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
@@ -66,7 +71,7 @@ impl Health {
             adapter_index: 0,
             // Block 0 is reserved for health checks
             blocks: vec![0],
-            slots: (0..16).collect(),
+            slots: (0..self.shard_info.block_size).collect(),
             cache_len: 0,
             chunk_len: None,
         };
@@ -84,15 +89,20 @@ impl Health {
     pub(crate) async fn check_classification(&mut self) -> bool {
         let classify_request = Request {
             id: LIVENESS_ID,
-            inputs: "San Francisco".to_string(),
-            tokenized_inputs: None,
+            inputs: "liveness".to_string(),
+            tokenized_inputs: Some(TokenizedInputs {
+                ids: vec![75],
+                input_chunks: vec![InputChunk {
+                    chunk: Some(input_chunk::Chunk::Text("liveness".to_string())),
+                }],
+            }),
             truncate: 10,
             prefill_logprobs: false,
             parameters: None,
             stopping_parameters: None,
             adapter_index: 0,
             blocks: vec![0],
-            slots: (0..16).collect(),
+            slots: (0..self.shard_info.block_size).collect(),
             cache_len: 0,
             chunk_len: None,
         };
@@ -109,15 +119,20 @@ impl Health {
     pub(crate) async fn check_embeddings(&mut self) -> bool {
         let embed_request = Request {
             id: LIVENESS_ID,
-            inputs: "San Francisco".to_string(),
-            tokenized_inputs: None,
+            inputs: "liveness".to_string(),
+            tokenized_inputs: Some(TokenizedInputs {
+                ids: vec![75],
+                input_chunks: vec![InputChunk {
+                    chunk: Some(input_chunk::Chunk::Text("liveness".to_string())),
+                }],
+            }),
             truncate: 10,
             prefill_logprobs: false,
             parameters: None,
             stopping_parameters: None,
             adapter_index: 0,
             blocks: vec![0],
-            slots: (0..16).collect(),
+            slots: (0..self.shard_info.block_size).collect(),
             cache_len: 0,
             chunk_len: None,
         };
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
index b6ffc9d2c..80afd2d32 100644
--- a/rust-toolchain.toml
+++ b/rust-toolchain.toml
@@ -1,3 +1,3 @@
 [toolchain]
-channel = "1.79.0"
-components = ["rustfmt", "clippy"]
\ No newline at end of file
+channel = "1.83.0"
+components = ["rustfmt", "clippy"]
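
The slot-count change above is easiest to see in isolation. The following is a minimal, self-contained sketch (not LoRAX's actual API) of the invariant the patch restores: the health-check request reserves exactly one KV-cache block, so its slot list must contain block_size entries rather than a hard-coded 16.

    // Minimal sketch: one reserved block must expose `block_size` slots.
    fn health_check_slots(block_size: u32) -> Vec<u32> {
        // Block 0 is reserved for health checks; its slots are 0..block_size.
        (0..block_size).collect()
    }

    fn main() {
        // With a block size of 16 the old hard-coded vector happened to be correct...
        assert_eq!(health_check_slots(16).len(), 16);
        // ...but with any other block size it was the wrong length.
        assert_eq!(health_check_slots(256).len(), 256);
    }
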
From 7416ea612fcbe1fbba3b744f933fe4caa72ea069 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Mon, 20 Jan 2025 09:04:24 -0800
Subject: [PATCH 04/15] Remove grammar constraint on tool calling support

---
 router/src/infer.rs        |  6 +++--
 router/src/lib.rs          |  3 ++-
 router/src/tool_grammar.rs | 47 +++++++++++++++++++-------------------
 3 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index 703dacd46..83a2d4e83 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -112,12 +112,14 @@ impl ChatTemplateRenderer {
             // if not, we need to append the tools to the last message
             let text = if self.use_default_tool_template {
                 match serde_json::to_string(&tools) {
-                    Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
+                    // Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
+                    Ok(tools_str) => format!("\n{}\n{}", tools_str, tool_prompt),
                     Err(e) => return Err(InferError::ToolError(e.to_string())),
                 }
             } else {
                 // if the `tools` variable is used in the template, we just append the tool_prompt
-                format!("\n---\n{}", tool_prompt)
+                // format!("\n---\n{}", tool_prompt)
+                format!("\n{}", tool_prompt)
             };
             if let Some(last_message) = messages.last_mut() {
                 if let Some(content) = &mut last_message.content {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index c3cf2cedc..0b14a91c2 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -858,7 +858,8 @@ impl ChatCompletionRequest {
 }

 pub fn default_tool_prompt() -> String {
-    "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
+    // "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
+    "".to_string()
 }

 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
diff --git a/router/src/tool_grammar.rs b/router/src/tool_grammar.rs
index 6a1f604ba..de189ca84 100644
--- a/router/src/tool_grammar.rs
+++ b/router/src/tool_grammar.rs
@@ -29,27 +29,27 @@ impl ToolGrammar {
         let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);

-        let mut tools = tools.clone();
-
-        // add the no_tool function to the tools
-        let no_tool = Tool {
-            r#type: "function".to_string(),
-            function: FunctionDefinition {
-                name: "no_tool".to_string(),
-                description: Some("Open ened response with no specific tool selected".to_string()),
-                parameters: json!({
-                    "type": "object",
-                    "properties": {
-                        "content": {
-                            "type": "string",
-                            "description": "The response content",
-                        }
-                    },
-                    "required": ["content"]
-                }),
-            },
-        };
-        tools.push(no_tool);
+        // let mut tools = tools.clone();
+
+        // // add the no_tool function to the tools
+        // let no_tool = Tool {
+        //     r#type: "function".to_string(),
+        //     function: FunctionDefinition {
+        //         name: "no_tool".to_string(),
+        //         description: Some("Open ened response with no specific tool selected".to_string()),
+        //         parameters: json!({
+        //             "type": "object",
+        //             "properties": {
+        //                 "content": {
+        //                     "type": "string",
+        //                     "description": "The response content",
+        //                 }
+        //             },
+        //             "required": ["content"]
+        //         }),
+        //     },
+        // };
+        // tools.push(no_tool);

         // if tools are provided and no tool_choice we default to the OneOf
         let tools_to_use = match tool_choice {
@@ -106,7 +106,7 @@ impl ToolGrammar {
             })
             .collect();

-        let tool_schema = JsonSchemaTool {
+        let _tool_schema = JsonSchemaTool {
             functions_map: FunctionsMap { functions },
             properties: Properties {
                 function: tools_to_use
@@ -118,6 +118,7 @@ impl ToolGrammar {
             },
         };

-        Ok((tools, Some(tool_schema)))
+        // Ok((tools, Some(tool_schema)))
+        Ok((tools, None))
     }
 }
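
With the grammar constraint gone, the default tool prompt is empty and the serialized tools array is appended to the last user message without the old "---" separator. Below is a rough sketch of what that appended suffix looks like for a single hypothetical get_weather tool; the tool definition itself is illustrative, not part of the patch.

    use serde_json::json;

    fn main() {
        let tools = json!([{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": { "city": { "type": "string" } },
                    "required": ["city"]
                }
            }
        }]);
        // default_tool_prompt() now returns an empty string.
        let tool_prompt = "";
        let tools_str = tools.to_string();
        // Mirrors the format!("\n{}\n{}", tools_str, tool_prompt) branch above.
        let suffix = format!("\n{}\n{}", tools_str, tool_prompt);
        println!("{suffix}");
    }
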
From f412568238c7dcd4b5f55589961399a011191040 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Mon, 20 Jan 2025 09:09:25 -0800
Subject: [PATCH 05/15] Revert rust toolchain changes

---
 rust-toolchain.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rust-toolchain.toml b/rust-toolchain.toml
index 80afd2d32..874939176 100644
--- a/rust-toolchain.toml
+++ b/rust-toolchain.toml
@@ -1,3 +1,3 @@
 [toolchain]
-channel = "1.83.0"
+channel = "1.79.0"
 components = ["rustfmt", "clippy"]

From cd90247d72127f9db1e784e0a90b7c7b16f6da57 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Mon, 20 Jan 2025 09:14:29 -0800
Subject: [PATCH 06/15] Rust 1.83.0

---
 .github/workflows/router_tests.yaml | 2 +-
 rust-toolchain.toml                 | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/router_tests.yaml b/.github/workflows/router_tests.yaml
index 6eb95f130..404b481d6 100644
--- a/.github/workflows/router_tests.yaml
+++ b/.github/workflows/router_tests.yaml
@@ -29,7 +29,7 @@ jobs:
       - name: Install Rust
         uses: actions-rs/toolchain@v1
         with:
-          toolchain: 1.79.0
+          toolchain: 1.83.0
           override: true
           components: rustfmt, clippy
       - name: Install Protoc
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
index 874939176..80afd2d32 100644
--- a/rust-toolchain.toml
+++ b/rust-toolchain.toml
@@ -1,3 +1,3 @@
 [toolchain]
-channel = "1.79.0"
+channel = "1.83.0"
 components = ["rustfmt", "clippy"]

From 5f251d9516d45426c7c2066e7d0a11a8dc0b4650 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Mon, 20 Jan 2025 09:20:50 -0800
Subject: [PATCH 07/15] Use 1.83

---
 Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Dockerfile b/Dockerfile
index eccefae58..0988daf58 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.83 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
From 50563ff89e14be0726bf2cf4cad05aa2475ad3b5 Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Tue, 21 Jan 2025 16:47:47 -0500
Subject: [PATCH 08/15] this kinda works

---
 router/src/server.rs | 158 ++++++++++++++++++++++++++++---------------
 1 file changed, 105 insertions(+), 53 deletions(-)

diff --git a/router/src/server.rs b/router/src/server.rs
index 3eb24521d..3129101fa 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -33,6 +33,7 @@ use futures::Stream;
 use lorax_client::{ShardInfo, ShardedClient};
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use once_cell::sync::OnceCell;
+use regex::Regex;
 use reqwest_middleware::ClientBuilder;
 use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
 use serde::{Deserialize, Serialize};
@@ -210,6 +211,104 @@ async fn completions_v1(
     }
 }

+fn parse_json_tool_call(
+    gen_text_value: Value,
+) -> Result<(Option<Vec<ToolCall>>, Option<String>), InferError> {
+    let function = gen_text_value.get("function").ok_or(InferError::ToolError(
+        "No function found in generated text".to_string(),
+    ))?;
+
+    let name = function
+        .get("_name")
+        .and_then(Value::as_str)
+        .ok_or(InferError::ToolError(
+            "No _name found in generated text".to_string(),
+        ))?
+        .to_string();
+
+    let mut arguments = function.clone();
+    if let Value::Object(ref mut props) = arguments {
+        props.remove("_name");
+    }
+    match name.as_str() {
+        "no_tool" => {
+            // parse the content message
+            let content_message = arguments
+                .get("content")
+                .and_then(Value::as_str)
+                .ok_or_else(|| {
+                    InferError::ToolError("No `content` found in generated text".to_string())
+                })?
+                .to_string();
+            Ok((None, Some(content_message)))
+        }
+        _ => {
+            let arguments = serde_json::to_string(&arguments).map_err(|e| {
+                InferError::ToolError(format!("Failed to serialize arguments: {}", e))
+            })?;
+            let tool_calls = vec![ToolCall {
+                id: "0".to_string(),
+                r#type: "function".to_string(),
+                function: ReturnFunctionDefinition {
+                    description: None,
+                    name,
+                    arguments,
+                },
+            }];
+            Ok((Some(tool_calls), None))
+        }
+    }
+}
+
+fn parse_xml_tool_call(gen: &str) -> Result<(Option<Vec<ToolCall>>, Option<String>), InferError> {
+    let tool_call_regex = Regex::new(r"(?s)<tool_call>(.*?)</tool_call>|<tool_call>(.*)")
+        .map_err(|e| InferError::ToolError(format!("Failed to create tool call regex: {}", e)))?;
+    // Check for tool call matches
+    if let Some(captures) = tool_call_regex.captures(gen) {
+        // Check for complete tool call (first capture group)
+        let json_content = if let Some(complete_match) = captures.get(1) {
+            complete_match.as_str()
+        }
+        // Check for incomplete tool call (second capture group)
+        else if let Some(incomplete_match) = captures.get(2) {
+            incomplete_match.as_str()
+        } else {
+            return Ok((None, Some(gen.to_string())));
+        };
+
+        // Parse the JSON content
+        let parsed_content: serde_json::Value =
+            serde_json::from_str(json_content.trim()).map_err(|e| {
+                InferError::ToolError(format!("Failed to parse tool call JSON content: {}", e))
+            })?;
+
+        // Extract name and arguments
+        let name = parsed_content["name"]
+            .as_str()
+            .ok_or_else(|| InferError::ToolError("Missing 'name' field in tool call".to_string()))?
+            .to_string();
+
+        let arguments = serde_json::to_string(&parsed_content["arguments"])
+            .map_err(|e| InferError::ToolError(format!("Failed to serialize arguments: {}", e)))?;
+
+        // Create tool call with the extracted content
+        let tool_calls = vec![ToolCall {
+            id: "0".to_string(),
+            r#type: "function".to_string(),
+            function: ReturnFunctionDefinition {
+                description: None,
+                name,
+                arguments,
+            },
+        }];
+
+        Ok((Some(tool_calls), None))
+    } else {
+        // If no tool call tags are found, return the original text
+        Ok((None, Some(gen.to_string())))
+    }
+}
+
 /// OpenAI compatible chat completions endpoint
 #[utoipa::path(
     post,
@@ -319,58 +418,10 @@ async fn chat_completions_v1(
         let mut choice_content = vec![];
         for (_, gen) in generations.iter().enumerate() {
             let (tool_calls, output) = if using_tools {
-                let gen_text_value: Value = serde_json::from_str(&gen).map_err(|e| {
-                    InferError::ToolError(format!(
-                        "Failed to parse generated text: {} {:?}",
-                        e, gen
-                    ))
-                })?;
-                let function = gen_text_value.get("function").ok_or(InferError::ToolError(
-                    "No function found in generated text".to_string(),
-                ))?;
-
-                let name = function
-                    .get("_name")
-                    .and_then(Value::as_str)
-                    .ok_or(InferError::ToolError(
-                        "No _name found in generated text".to_string(),
-                    ))?
-                    .to_string();
-
-                let mut arguments = function.clone();
-                if let Value::Object(ref mut props) = arguments {
-                    props.remove("_name");
-                }
-                match name.as_str() {
-                    "no_tool" => {
-                        // parse the content message
-                        let content_message = arguments
-                            .get("content")
-                            .and_then(Value::as_str)
-                            .ok_or_else(|| {
-                                InferError::ToolError(
-                                    "No `content` found in generated text".to_string(),
-                                )
-                            })?
-                            .to_string();
-                        (None, Some(content_message))
-                    }
-                    _ => {
-                        let arguments = serde_json::to_string(&arguments).map_err(|e| {
-                            InferError::ToolError(format!("Failed to serialize arguments: {}", e))
-                        })?;
-                        let tool_calls = vec![ToolCall {
-                            id: "0".to_string(),
-                            r#type: "function".to_string(),
-                            function: ReturnFunctionDefinition {
-                                description: None,
-                                name,
-                                arguments,
-                            },
-                        }];
-                        (Some(tool_calls), None)
-                    }
-                }
+                match serde_json::from_str::<Value>(gen) {
+                    Ok(gen_text_value) => parse_json_tool_call(gen_text_value),
+                    Err(_) => parse_xml_tool_call(gen),
+                }?
             } else {
                 (None, Some(gen.clone()))
             };
@@ -435,7 +486,8 @@ pub(crate) fn prepare_chat_input(
             messages,
             Some((updated_tools, tool_prompt.into())),
         )?;
-        return Ok((inputs, grammar, tool_schema.is_some()));
+        // return Ok((inputs, grammar, tool_schema.is_some()));
+        return Ok((inputs, grammar, true));
     }

     // if no response_format or tools are set simply apply the chat template to generate inputs
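
For reference, the JSON branch above still expects the grammar-style shape used before this series: an object whose "function" member carries a "_name" selector plus the arguments. A small, self-contained sketch of that extraction follows; the get_weather tool is an illustrative assumption, not something defined in this patch.

    use serde_json::Value;

    fn main() {
        let gen = r#"{"function": {"_name": "get_weather", "city": "Paris"}}"#;
        let value: Value = serde_json::from_str(gen).unwrap();

        let function = value.get("function").expect("no `function` key");
        let name = function["_name"].as_str().expect("no `_name` key").to_string();

        // Everything except "_name" becomes the serialized arguments string,
        // mirroring parse_json_tool_call above.
        let mut arguments = function.clone();
        if let Value::Object(ref mut props) = arguments {
            props.remove("_name");
        }
        println!("name = {name}, arguments = {arguments}");
        // Prints: name = get_weather, arguments = {"city":"Paris"}
    }
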
From a1cf8cb64aa278c68816f389911db053d28df90c Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Tue, 21 Jan 2025 16:47:47 -0500
Subject: [PATCH 09/15] Add XML tool call parsing by looking for the <tool_call> tag

---
 router/src/server.rs | 158 ++++++++++++++++++++++++++++---------------
 1 file changed, 105 insertions(+), 53 deletions(-)

diff --git a/router/src/server.rs b/router/src/server.rs
index 3eb24521d..3129101fa 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -33,6 +33,7 @@ use futures::Stream;
 use lorax_client::{ShardInfo, ShardedClient};
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use once_cell::sync::OnceCell;
+use regex::Regex;
 use reqwest_middleware::ClientBuilder;
 use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
 use serde::{Deserialize, Serialize};
@@ -210,6 +211,104 @@ async fn completions_v1(
     }
 }

+fn parse_json_tool_call(
+    gen_text_value: Value,
+) -> Result<(Option<Vec<ToolCall>>, Option<String>), InferError> {
+    let function = gen_text_value.get("function").ok_or(InferError::ToolError(
+        "No function found in generated text".to_string(),
+    ))?;
+
+    let name = function
+        .get("_name")
+        .and_then(Value::as_str)
+        .ok_or(InferError::ToolError(
+            "No _name found in generated text".to_string(),
+        ))?
+        .to_string();
+
+    let mut arguments = function.clone();
+    if let Value::Object(ref mut props) = arguments {
+        props.remove("_name");
+    }
+    match name.as_str() {
+        "no_tool" => {
+            // parse the content message
+            let content_message = arguments
+                .get("content")
+                .and_then(Value::as_str)
+                .ok_or_else(|| {
+                    InferError::ToolError("No `content` found in generated text".to_string())
+                })?
+                .to_string();
+            Ok((None, Some(content_message)))
+        }
+        _ => {
+            let arguments = serde_json::to_string(&arguments).map_err(|e| {
+                InferError::ToolError(format!("Failed to serialize arguments: {}", e))
+            })?;
+            let tool_calls = vec![ToolCall {
+                id: "0".to_string(),
+                r#type: "function".to_string(),
+                function: ReturnFunctionDefinition {
+                    description: None,
+                    name,
+                    arguments,
+                },
+            }];
+            Ok((Some(tool_calls), None))
+        }
+    }
+}
+
+fn parse_xml_tool_call(gen: &str) -> Result<(Option<Vec<ToolCall>>, Option<String>), InferError> {
+    let tool_call_regex = Regex::new(r"(?s)<tool_call>(.*?)</tool_call>|<tool_call>(.*)")
+        .map_err(|e| InferError::ToolError(format!("Failed to create tool call regex: {}", e)))?;
+    // Check for tool call matches
+    if let Some(captures) = tool_call_regex.captures(gen) {
+        // Check for complete tool call (first capture group)
+        let json_content = if let Some(complete_match) = captures.get(1) {
+            complete_match.as_str()
+        }
+        // Check for incomplete tool call (second capture group)
+        else if let Some(incomplete_match) = captures.get(2) {
+            incomplete_match.as_str()
+        } else {
+            return Ok((None, Some(gen.to_string())));
+        };
+
+        // Parse the JSON content
+        let parsed_content: serde_json::Value =
+            serde_json::from_str(json_content.trim()).map_err(|e| {
+                InferError::ToolError(format!("Failed to parse tool call JSON content: {}", e))
+            })?;
+
+        // Extract name and arguments
+        let name = parsed_content["name"]
+            .as_str()
+            .ok_or_else(|| InferError::ToolError("Missing 'name' field in tool call".to_string()))?
+            .to_string();
+
+        let arguments = serde_json::to_string(&parsed_content["arguments"])
+            .map_err(|e| InferError::ToolError(format!("Failed to serialize arguments: {}", e)))?;
+
+        // Create tool call with the extracted content
+        let tool_calls = vec![ToolCall {
+            id: "0".to_string(),
+            r#type: "function".to_string(),
+            function: ReturnFunctionDefinition {
+                description: None,
+                name,
+                arguments,
+            },
+        }];
+
+        Ok((Some(tool_calls), None))
+    } else {
+        // If no tool call tags are found, return the original text
+        Ok((None, Some(gen.to_string())))
+    }
+}
+
 /// OpenAI compatible chat completions endpoint
 #[utoipa::path(
     post,
@@ -319,58 +418,10 @@ async fn chat_completions_v1(
         let mut choice_content = vec![];
         for (_, gen) in generations.iter().enumerate() {
             let (tool_calls, output) = if using_tools {
-                let gen_text_value: Value = serde_json::from_str(&gen).map_err(|e| {
-                    InferError::ToolError(format!(
-                        "Failed to parse generated text: {} {:?}",
-                        e, gen
-                    ))
-                })?;
-                let function = gen_text_value.get("function").ok_or(InferError::ToolError(
-                    "No function found in generated text".to_string(),
-                ))?;
-
-                let name = function
-                    .get("_name")
-                    .and_then(Value::as_str)
-                    .ok_or(InferError::ToolError(
-                        "No _name found in generated text".to_string(),
-                    ))?
-                    .to_string();
-
-                let mut arguments = function.clone();
-                if let Value::Object(ref mut props) = arguments {
-                    props.remove("_name");
-                }
-                match name.as_str() {
-                    "no_tool" => {
-                        // parse the content message
-                        let content_message = arguments
-                            .get("content")
-                            .and_then(Value::as_str)
-                            .ok_or_else(|| {
-                                InferError::ToolError(
-                                    "No `content` found in generated text".to_string(),
-                                )
-                            })?
-                            .to_string();
-                        (None, Some(content_message))
-                    }
-                    _ => {
-                        let arguments = serde_json::to_string(&arguments).map_err(|e| {
-                            InferError::ToolError(format!("Failed to serialize arguments: {}", e))
-                        })?;
-                        let tool_calls = vec![ToolCall {
-                            id: "0".to_string(),
-                            r#type: "function".to_string(),
-                            function: ReturnFunctionDefinition {
-                                description: None,
-                                name,
-                                arguments,
-                            },
-                        }];
-                        (Some(tool_calls), None)
-                    }
-                }
+                match serde_json::from_str::<Value>(gen) {
+                    Ok(gen_text_value) => parse_json_tool_call(gen_text_value),
+                    Err(_) => parse_xml_tool_call(gen),
+                }?
             } else {
                 (None, Some(gen.clone()))
            };
@@ -435,7 +486,8 @@ pub(crate) fn prepare_chat_input(
             messages,
             Some((updated_tools, tool_prompt.into())),
         )?;
-        return Ok((inputs, grammar, tool_schema.is_some()));
+        // return Ok((inputs, grammar, tool_schema.is_some()));
+        return Ok((inputs, grammar, true));
    }

     // if no response_format or tools are set simply apply the chat template to generate inputs
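
The new parse_xml_tool_call helper recognizes both a complete <tool_call>...</tool_call> block and one whose closing tag was cut off before the end of generation, and otherwise treats the output as plain text. Below is a standalone sketch of that matching logic, assuming the <tool_call>/</tool_call> delimiters and a hypothetical get_weather tool; it is an illustration rather than the router code itself, and it needs the regex and serde_json crates.

    use regex::Regex;
    use serde_json::Value;

    fn main() {
        // Same pattern shape as parse_xml_tool_call above.
        let re = Regex::new(r"(?s)<tool_call>(.*?)</tool_call>|<tool_call>(.*)").unwrap();

        let samples = [
            // Complete tool call: matched by the first capture group.
            r#"<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>"#,
            // Truncated tool call (closing tag never generated): second capture group.
            r#"<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}"#,
            // No tags at all: treated as a plain chat response.
            "The weather in Paris is mild today.",
        ];

        for gen in samples {
            match re.captures(gen) {
                Some(caps) => {
                    let json = caps.get(1).or_else(|| caps.get(2)).unwrap().as_str();
                    let parsed: Value = serde_json::from_str(json.trim()).unwrap();
                    println!("tool call: name={} arguments={}", parsed["name"], parsed["arguments"]);
                }
                None => println!("plain content: {gen}"),
            }
        }
    }
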
From 725b174e6548f81b83245abb7c086cbb40fa3421 Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Tue, 21 Jan 2025 17:09:21 -0500
Subject: [PATCH 10/15] fix warning

---
 router/src/tool_grammar.rs | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/router/src/tool_grammar.rs b/router/src/tool_grammar.rs
index de189ca84..2ecadda0e 100644
--- a/router/src/tool_grammar.rs
+++ b/router/src/tool_grammar.rs
@@ -1,8 +1,5 @@
 use crate::infer::InferError;
-use crate::{
-    FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
-    ToolType,
-};
+use crate::{FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, ToolType};
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;

From 417b20e522aecf1031ea66b566a48650d0e94e72 Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Wed, 22 Jan 2025 13:59:18 -0500
Subject: [PATCH 11/15] Catch parsing errors and return the generated model output

---
 router/src/server.rs | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/router/src/server.rs b/router/src/server.rs
index 3129101fa..d5b24ca9d 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -418,10 +418,15 @@ async fn chat_completions_v1(
         let mut choice_content = vec![];
         for (_, gen) in generations.iter().enumerate() {
             let (tool_calls, output) = if using_tools {
-                match serde_json::from_str::<Value>(gen) {
+                let tool_call_result = match serde_json::from_str::<Value>(gen) {
                     Ok(gen_text_value) => parse_json_tool_call(gen_text_value),
                     Err(_) => parse_xml_tool_call(gen),
-                }?
+                };
+                match tool_call_result {
+                    Ok((tool_calls, output)) => (tool_calls, output),
+                    // TODO: (magdy) How should we tell the user that the tool call failed?
+                    Err(_) => (None, Some(gen.clone())),
+                }
             } else {
                 (None, Some(gen.clone()))
             };
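
With this change the router degrades gracefully: it first tries the grammar-style JSON shape, then the <tool_call> XML form, and finally returns the raw generation as ordinary content if neither parses. A simplified, self-contained sketch of that ordering follows; the types and sample strings are illustrative assumptions, not the router's InferError/ToolCall machinery.

    use regex::Regex;
    use serde_json::Value;

    fn extract_tool_name(gen: &str) -> Option<String> {
        // 1. Grammar-style JSON: {"function": {"_name": ...}}
        if let Ok(v) = serde_json::from_str::<Value>(gen) {
            return v["function"]["_name"].as_str().map(str::to_string);
        }
        // 2. XML-style: <tool_call>{"name": ...}</tool_call>
        let re = Regex::new(r"(?s)<tool_call>(.*?)</tool_call>|<tool_call>(.*)").ok()?;
        let caps = re.captures(gen)?;
        let json = caps.get(1).or_else(|| caps.get(2))?.as_str();
        let parsed: Value = serde_json::from_str(json.trim()).ok()?;
        parsed["name"].as_str().map(str::to_string)
    }

    fn main() {
        for gen in [
            r#"{"function": {"_name": "get_weather", "city": "Paris"}}"#,
            r#"<tool_call>{"name": "get_weather", "arguments": {}}</tool_call>"#,
            "Sorry, I cannot call a tool for that request.",
        ] {
            match extract_tool_name(gen) {
                Some(name) => println!("tool call: {name}"),
                // 3. No parse: fall back to the raw text, like Err(_) => (None, Some(gen.clone())).
                None => println!("plain content: {gen}"),
            }
        }
    }
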
From dd4452b6c6b10bdcd5b74151cdd115a89a77a311 Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Wed, 22 Jan 2025 15:07:52 -0500
Subject: [PATCH 12/15] make dynamic scaling work in general case

---
 server/lorax_server/layers/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/lorax_server/layers/fp8.py b/server/lorax_server/layers/fp8.py
index 1641eb42a..c11f23c15 100644
--- a/server/lorax_server/layers/fp8.py
+++ b/server/lorax_server/layers/fp8.py
@@ -43,7 +43,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input=input,
             qweight=self.qweight,
             weight_scale=self.weight_scale,
-            input_scale=None,
+            input_scale=self.input_scale,
             qbias=self.qbias,
         )

From 4ac301487c0e119a84076a99f93468baf36b54aa Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Thu, 23 Jan 2025 18:52:21 -0500
Subject: [PATCH 13/15] add tool calls to messages

---
 router/src/lib.rs | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index 0b14a91c2..06650221c 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -581,7 +581,7 @@ pub struct Url {
 }

 #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
-pub(crate) struct ToolCall {
+pub struct ToolCall {
     pub id: String,
     pub r#type: String,
     pub function: ReturnFunctionDefinition,
@@ -603,6 +603,8 @@ pub struct Message {
     #[schema(example = "My name is David and I")]
     pub content: Option,
     #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCall>>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "\"David\"")]
     name: Option<String>,
 }
@@ -642,6 +644,8 @@ pub struct TextMessage {
     pub role: String,
     #[schema(example = "My name is David and I")]
     pub content: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCall>>,
 }

 impl From<Message> for TextMessage {
@@ -660,6 +664,7 @@ impl From<Message> for TextMessage {
                     .join(""),
                 None => String::new(),
             },
+            tool_calls: value.tool_calls,
         }
     }
 }

From 5a848be48dfd3be43a45a9c605df5bcf372c26af Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Thu, 23 Jan 2025 18:57:21 -0500
Subject: [PATCH 14/15] remove warning

---
 router/src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index 06650221c..b046815a6 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -957,7 +957,7 @@ pub(crate) struct FunctionDefinition {
 }

 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
-pub(crate) struct ReturnFunctionDefinition {
+pub struct ReturnFunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,

From 0253b0fad1780e7a227dbcefa3cda3cbbe3863ce Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Fri, 24 Jan 2025 15:36:34 -0500
Subject: [PATCH 15/15] fix escaping xml tool calls when the input is valid json

---
 router/src/server.rs | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/router/src/server.rs b/router/src/server.rs
index d5b24ca9d..941cd354a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -288,8 +288,14 @@ fn parse_xml_tool_call(gen: &str) -> Result<(Option<Vec<ToolCall>>, Option