Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions src/providers/openai_codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct OpenAiCodexProvider {
client: Client,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
struct ResponsesRequest {
model: String,
input: Vec<ResponsesInput>,
Expand All @@ -38,13 +38,13 @@ struct ResponsesRequest {
parallel_tool_calls: bool,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
struct ResponsesInput {
role: String,
content: Vec<ResponsesInputContent>,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
struct ResponsesInputContent {
#[serde(rename = "type")]
kind: String,
Expand All @@ -54,12 +54,12 @@ struct ResponsesInputContent {
image_url: Option<String>,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
struct ResponsesTextOptions {
verbosity: String,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
struct ResponsesReasoningOptions {
effort: String,
summary: String,
Expand Down Expand Up @@ -613,7 +613,54 @@ impl OpenAiCodexProvider {
return Err(super::api_error("OpenAI Codex", response).await);
}

decode_responses_body(response).await
// Try to decode streaming response first
match decode_responses_body(response).await {
Ok(text) => Ok(text),
// If streaming fails, retry with non-streaming request
Err(e) => {
tracing::warn!(
error = %e,
"OpenAI Codex streaming failed, retrying with non-streaming request"
);
let mut non_streaming_request = request.clone();
non_streaming_request.stream = false;

let non_streaming_response = self
.client
.post(&self.responses_url)
.header("Authorization", format!("Bearer {bearer_token}"))
.header("OpenAI-Beta", "responses=experimental")
.header("originator", "pi")
.header("Content-Type", "application/json");

let non_streaming_response = if let Some(account_id) = account_id.as_deref() {
non_streaming_response.header("chatgpt-account-id", account_id)
} else {
non_streaming_response
};

let non_streaming_response = if use_gateway_api_key_auth {
if let Some(access_token) = access_token.as_deref() {
non_streaming_response.header("x-openai-access-token", access_token)
} else {
non_streaming_response
}
} else {
non_streaming_response
};

let response = non_streaming_response
.json(&non_streaming_request)
.send()
.await?;

if !response.status().is_success() {
return Err(super::api_error("OpenAI Codex", response).await);
}

decode_responses_body(response).await
}
}
}
}

Expand Down
162 changes: 124 additions & 38 deletions src/tools/git_operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,15 @@ impl GitOperationsTool {
)
}

async fn run_git_command(&self, args: &[&str]) -> anyhow::Result<String> {
async fn run_git_command(
&self,
args: &[&str],
cwd: Option<&std::path::Path>,
) -> anyhow::Result<String> {
let work_dir = cwd.unwrap_or(&self.workspace_dir);
let output = tokio::process::Command::new("git")
.args(args)
.current_dir(&self.workspace_dir)
.current_dir(work_dir)
.output()
.await?;

Expand All @@ -80,9 +85,13 @@ impl GitOperationsTool {
Ok(String::from_utf8_lossy(&output.stdout).to_string())
}

async fn git_status(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
async fn git_status(
&self,
_args: serde_json::Value,
cwd: Option<&std::path::Path>,
) -> anyhow::Result<ToolResult> {
let output = self
.run_git_command(&["status", "--porcelain=2", "--branch"])
.run_git_command(&["status", "--porcelain=2", "--branch"], cwd)
.await?;

// Parse git status output into structured format
Expand Down Expand Up @@ -131,7 +140,11 @@ impl GitOperationsTool {
})
}

async fn git_diff(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
async fn git_diff(
&self,
args: serde_json::Value,
cwd: Option<&std::path::Path>,
) -> anyhow::Result<ToolResult> {
let files = args.get("files").and_then(|v| v.as_str()).unwrap_or(".");
let cached = args
.get("cached")
Expand All @@ -148,7 +161,7 @@ impl GitOperationsTool {
git_args.push("--");
git_args.push(files);

let output = self.run_git_command(&git_args).await?;
let output = self.run_git_command(&git_args, cwd).await?;

// Parse diff into structured hunks
let mut result = serde_json::Map::new();
Expand Down Expand Up @@ -210,18 +223,25 @@ impl GitOperationsTool {
})
}

async fn git_log(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
async fn git_log(
&self,
args: serde_json::Value,
cwd: Option<&std::path::Path>,
) -> anyhow::Result<ToolResult> {
let limit_raw = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10);
let limit = usize::try_from(limit_raw).unwrap_or(usize::MAX).min(1000);
let limit_str = limit.to_string();

let output = self
.run_git_command(&[
"log",
&format!("-{limit_str}"),
"--pretty=format:%H|%an|%ae|%ad|%s",
"--date=iso",
])
.run_git_command(
&[
"log",
&format!("-{limit_str}"),
"--pretty=format:%H|%an|%ae|%ad|%s",
"--date=iso",
],
cwd,
)
.await?;

let mut commits = Vec::new();
Expand All @@ -247,9 +267,13 @@ impl GitOperationsTool {
})
}

async fn git_branch(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
async fn git_branch(
&self,
_args: serde_json::Value,
cwd: Option<&std::path::Path>,
) -> anyhow::Result<ToolResult> {
let output = self
.run_git_command(&["branch", "--format=%(refname:short)|%(HEAD)"])
.run_git_command(&["branch", "--format=%(refname:short)|%(HEAD)"], cwd)
.await?;

let mut branches = Vec::new();
Expand Down Expand Up @@ -287,7 +311,11 @@ impl GitOperationsTool {
}
}

async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
async fn git_commit(
&self,
args: serde_json::Value,
cwd: Option<&std::path::Path>,
) -> anyhow::Result<ToolResult> {
let message = args
.get("message")
.and_then(|v| v.as_str())
Expand All @@ -308,7 +336,7 @@ impl GitOperationsTool {
// Limit message length
let message = Self::truncate_commit_message(&sanitized);

let output = self.run_git_command(&["commit", "-m", &message]).await;
let output = self.run_git_command(&["commit", "-m", &message], cwd).await;

match output {
Ok(_) => Ok(ToolResult {
Expand All @@ -324,7 +352,11 @@ impl GitOperationsTool {
}
}

async fn git_add(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
async fn git_add(
&self,
args: serde_json::Value,
cwd: Option<&std::path::Path>,
) -> anyhow::Result<ToolResult> {
let paths = args
.get("paths")
.and_then(|v| v.as_str())
Expand All @@ -333,7 +365,7 @@ impl GitOperationsTool {
// Validate paths against injection patterns
self.sanitize_git_args(paths)?;

let output = self.run_git_command(&["add", "--", paths]).await;
let output = self.run_git_command(&["add", "--", paths], cwd).await;

match output {
Ok(_) => Ok(ToolResult {
Expand All @@ -349,7 +381,11 @@ impl GitOperationsTool {
}
}

async fn git_checkout(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
async fn git_checkout(
&self,
args: serde_json::Value,
cwd: Option<&std::path::Path>,
) -> anyhow::Result<ToolResult> {
let branch = args
.get("branch")
.and_then(|v| v.as_str())
Expand All @@ -369,7 +405,7 @@ impl GitOperationsTool {
anyhow::bail!("Branch name contains invalid characters");
}

let output = self.run_git_command(&["checkout", branch_name]).await;
let output = self.run_git_command(&["checkout", branch_name], cwd).await;

match output {
Ok(_) => Ok(ToolResult {
Expand All @@ -385,24 +421,28 @@ impl GitOperationsTool {
}
}

async fn git_stash(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
async fn git_stash(
&self,
args: serde_json::Value,
cwd: Option<&std::path::Path>,
) -> anyhow::Result<ToolResult> {
let action = args
.get("action")
.and_then(|v| v.as_str())
.unwrap_or("push");

let output = match action {
"push" | "save" => {
self.run_git_command(&["stash", "push", "-m", "auto-stash"])
self.run_git_command(&["stash", "push", "-m", "auto-stash"], cwd)
.await
}
"pop" => self.run_git_command(&["stash", "pop"]).await,
"list" => self.run_git_command(&["stash", "list"]).await,
"pop" => self.run_git_command(&["stash", "pop"], cwd).await,
"list" => self.run_git_command(&["stash", "list"], cwd).await,
"drop" => {
let index_raw = args.get("index").and_then(|v| v.as_u64()).unwrap_or(0);
let index = i32::try_from(index_raw)
.map_err(|_| anyhow::anyhow!("stash index too large: {index_raw}"))?;
self.run_git_command(&["stash", "drop", &format!("stash@{{{index}}}")])
self.run_git_command(&["stash", "drop", &format!("stash@{{{index}}}")], cwd)
.await
}
_ => anyhow::bail!("Unknown stash action: {action}. Use: push, pop, list, drop"),
Expand Down Expand Up @@ -442,6 +482,10 @@ impl Tool for GitOperationsTool {
"enum": ["status", "diff", "log", "branch", "commit", "add", "checkout", "stash"],
"description": "Git operation to perform"
},
"path": {
"type": "string",
"description": "Working directory for the git operation (relative to workspace or absolute path within workspace)"
},
"message": {
"type": "string",
"description": "Commit message (for 'commit' operation)"
Expand Down Expand Up @@ -492,10 +536,37 @@ impl Tool for GitOperationsTool {
}
};

// Check if we're in a git repository
if !self.workspace_dir.join(".git").exists() {
// Try to find .git in parent directories
let mut current_dir = self.workspace_dir.as_path();
// Extract and resolve the path parameter (working directory for git commands)
let cwd: Option<std::path::PathBuf> =
args.get("path").and_then(|v| v.as_str()).and_then(|p| {
if p.is_empty() {
return None;
}
// Resolve the path - if it's absolute, use it; if relative, join with workspace
let resolved = if p.starts_with('/') {
std::path::PathBuf::from(p)
} else {
self.workspace_dir.join(p)
};
// Canonicalize to resolve any ".." or "." components
resolved.canonicalize().ok().map(|p| {
// Ensure the resolved path is within the workspace
if p.starts_with(&self.workspace_dir) {
p
} else {
// If canonicalization fails to keep it in workspace, use original resolved path
resolved
}
})
});

// Determine the effective working directory for git operations
let effective_cwd: std::path::PathBuf = cwd.unwrap_or_else(|| self.workspace_dir.clone());

// Check if we're in a git repository (check effective_cwd, not just workspace root)
if !effective_cwd.join(".git").exists() {
// Try to find .git in parent directories of effective_cwd
let mut current_dir = effective_cwd.as_path();
let mut found_git = false;
while current_dir.parent().is_some() {
if current_dir.join(".git").exists() {
Expand All @@ -505,6 +576,18 @@ impl Tool for GitOperationsTool {
current_dir = current_dir.parent().unwrap();
}

// Also check workspace root if not found in effective_cwd parents
if !found_git && !self.workspace_dir.join(".git").exists() {
let mut current_dir = self.workspace_dir.as_path();
while current_dir.parent().is_some() {
if current_dir.join(".git").exists() {
found_git = true;
break;
}
current_dir = current_dir.parent().unwrap();
}
}

if !found_git {
return Ok(ToolResult {
success: false,
Expand Down Expand Up @@ -549,14 +632,17 @@ impl Tool for GitOperationsTool {

// Execute the requested operation
match operation {
"status" => self.git_status(args).await,
"diff" => self.git_diff(args).await,
"log" => self.git_log(args).await,
"branch" => self.git_branch(args).await,
"commit" => self.git_commit(args).await,
"add" => self.git_add(args).await,
"checkout" => self.git_checkout(args).await,
"stash" => self.git_stash(args).await,
"status" => self.git_status(args, effective_cwd.as_path().into()).await,
"diff" => self.git_diff(args, effective_cwd.as_path().into()).await,
"log" => self.git_log(args, effective_cwd.as_path().into()).await,
"branch" => self.git_branch(args, effective_cwd.as_path().into()).await,
"commit" => self.git_commit(args, effective_cwd.as_path().into()).await,
"add" => self.git_add(args, effective_cwd.as_path().into()).await,
"checkout" => {
self.git_checkout(args, effective_cwd.as_path().into())
.await
}
"stash" => self.git_stash(args, effective_cwd.as_path().into()).await,
_ => Ok(ToolResult {
success: false,
output: String::new(),
Expand Down
Loading