"""
Grok (xAI) Custom Provider for ChainForge

This provider integrates xAI's Grok models with ChainForge.
You'll need to install the openai package and set your xAI API key.

Installation:
pip install openai

Usage:
1. Get your API key from https://console.x.ai/
2. Replace '<YOUR_XAI_API_KEY>' below with your actual API key, or leave the
   placeholder unchanged and set the XAI_API_KEY environment variable instead
3. Save this file and import it into ChainForge Custom Providers settings
"""

from chainforge.providers import provider
from typing import List, Dict, Any, Optional
import os

# xAI API key. Either paste your key here, or leave the placeholder unchanged and
# set the XAI_API_KEY environment variable: the environment variable is only read
# while API_KEY still equals the '<YOUR_XAI_API_KEY>' placeholder.
# NOTE: prefer the environment variable if this file is shared or committed to
# version control, so the secret does not end up in the repository.
API_KEY = '<YOUR_XAI_API_KEY>'  # Replace with your actual API key, keeping the surrounding quotes
# Equivalent explicit form: API_KEY = os.getenv('XAI_API_KEY')

# Settings form shown by ChainForge for this provider, expressed as
# react-jsonschema-form data:
#   "settings" - one JSON-schema definition per generation parameter
#   "ui"       - matching uiSchema hints (help text, slider widgets)
# The "stream" field is informational only; the request helper in this file
# never forwards it to the API.
GROK_SETTINGS_SCHEMA = dict(
    settings=dict(
        temperature=dict(
            type="number",
            title="temperature",
            description="Controls randomness in the response. Higher values make output more random.",
            default=0.7,
            minimum=0.0,
            maximum=2.0,
            multipleOf=0.01,
        ),
        max_tokens=dict(
            type="integer",
            title="max_tokens",
            description="Maximum number of tokens to generate in the response.",
            default=1024,
            minimum=1,
            maximum=8192,
        ),
        top_p=dict(
            type="number",
            title="top_p",
            description="Nucleus sampling parameter. Lower values make responses more focused.",
            default=0.9,
            minimum=0.01,
            maximum=1.0,
            multipleOf=0.01,
        ),
        stream=dict(
            type="boolean",
            title="stream",
            description="Whether to stream the response (not supported in ChainForge UI)",
            default=False,
        ),
    ),
    # uiSchema keys contain ':' and so cannot be keyword arguments; use literals.
    ui=dict(
        temperature={
            "ui:help": "Controls creativity. 0 = deterministic, 2 = very creative",
            "ui:widget": "range",
        },
        max_tokens={
            "ui:help": "Maximum response length in tokens",
            "ui:widget": "range",
        },
        top_p={
            "ui:help": "Alternative to temperature for controlling randomness",
            "ui:widget": "range",
        },
        stream={
            "ui:help": "Streaming not supported in ChainForge interface",
        },
    ),
)

def _make_grok_request(prompt: str, model: str, chat_history: Optional[List[Dict[str, str]]] = None, **kwargs) -> str:
    """
    Send a chat-completion request to xAI Grok and return the reply text.

    xAI serves an OpenAI-compatible API, so the ``openai`` client is reused
    with its ``base_url`` pointed at ``https://api.x.ai/v1``.

    Args:
        prompt: Text sent as the final ``"user"`` message.
        model: xAI model identifier, e.g. ``"grok-4"``.
        chat_history: Optional earlier turns, each a dict with ``"role"`` and
            ``"content"`` keys. A missing role becomes ``"user"`` and missing
            content becomes ``""``.
        **kwargs: Generation settings. Only ``temperature`` (default 0.7),
            ``max_tokens`` (default 1024) and ``top_p`` (default 0.9) are sent;
            a setting explicitly passed as ``None`` is left out of the request.
            Any other key (e.g. ``stream``) is ignored.

    Returns:
        The reply text (``""`` if the API returned a message without text
        content), or a string beginning with ``"Error: "`` if the API call
        failed.

    Raises:
        ImportError: If the ``openai`` package is not installed.
        openai.OpenAIError: If no xAI API key is configured.
    """
    try:
        from openai import AuthenticationError, OpenAI, OpenAIError
    except ImportError as err:
        raise ImportError("Please install the openai package: pip install openai") from err

    # Use the key pasted into API_KEY; while it still holds the placeholder,
    # fall back to the XAI_API_KEY environment variable.
    api_key = API_KEY if API_KEY != '<YOUR_XAI_API_KEY>' else os.getenv('XAI_API_KEY')
    if not api_key:
        # Fail fast: given api_key=None the OpenAI client would silently pick up
        # OPENAI_API_KEY and send that unrelated secret to api.x.ai.
        raise OpenAIError(
            "No xAI API key configured. Set API_KEY in this file or the "
            "XAI_API_KEY environment variable."
        )

    client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")

    # Prior turns (if any) first, then the current prompt as the last user turn.
    messages = [
        {"role": msg.get("role", "user"), "content": msg.get("content", "")}
        for msg in (chat_history or [])
    ]
    messages.append({"role": "user", "content": prompt})

    api_params = {
        "model": model,
        "messages": messages,
        "temperature": kwargs.get("temperature", 0.7),
        "max_tokens": kwargs.get("max_tokens", 1024),
        "top_p": kwargs.get("top_p", 0.9),
    }
    # Omit settings the caller explicitly set to None rather than sending null.
    api_params = {k: v for k, v in api_params.items() if v is not None}

    try:
        print(f"Calling xAI Grok model {model} with {len(messages)} messages...")
        response = client.chat.completions.create(**api_params)
        content = response.choices[0].message.content
    except Exception as e:
        # Failures are reported as the returned text (prefixed "Error: ")
        # instead of being raised.
        error_msg = f"Error calling xAI Grok API: {str(e)}"
        print(error_msg)
        # A 401 surfaces as AuthenticationError whose message typically says
        # "API key" (with a space), so the substring test alone rarely matched.
        if isinstance(e, AuthenticationError) or "api_key" in str(e).lower():
            error_msg += "\nPlease check your API key is set correctly."
        return f"Error: {error_msg}"

    # message.content is Optional[str] in the openai types; keep the str contract.
    return content if content is not None else ""

@provider(
    name="Grok (xAI)",
    emoji="🤖",
    models=[
        "grok-4-0709",          # Grok 4, pinned release
        "grok-4",               # Grok 4 (latest)
        "grok-2-1212",          # Grok 2, 1212 release
        "grok-2-vision-1212",   # Grok 2 vision variant (only text is sent by this provider)
        "grok-2",               # Grok 2 (latest)
        "grok-beta",            # Beta model
    ],
    rate_limit="sequential",    # issue queries one at a time to stay within rate limits
    settings_schema=GROK_SETTINGS_SCHEMA,
)
def GrokCompletion(
    prompt: str,
    model: str = "grok-4",
    chat_history: Optional[List[Dict[str, str]]] = None,
    **kwargs
) -> str:
    """
    ChainForge entry point for xAI Grok completions.

    Thin wrapper that forwards everything to ``_make_grok_request``. It works
    for single prompts and, when ``chat_history`` is supplied, for multi-turn
    chats. Only text is sent, so vision-capable models receive text prompts
    only.

    Args:
        prompt: Text of the current user turn.
        model: One of the model names registered in the decorator
            (default ``"grok-4"``).
        chat_history: Optional earlier turns as ``{"role", "content"}`` dicts.
        **kwargs: Settings from the form (``temperature``, ``max_tokens``,
            ``top_p``; other keys are passed through and ignored downstream).

    Returns:
        The model's reply text, or a string beginning with ``"Error: "`` when
        the API call fails.
    """
    return _make_grok_request(
        prompt=prompt,
        model=model,
        chat_history=chat_history,
        **kwargs
    )