77"""
88import json
99import uuid
10- from typing import Sequence , Union
10+ from typing import Sequence , Union , List , Optional
1111from langgraph .graph import StateGraph
1212from langgraph .graph import END , START
1313from langchain_core .runnables import RunnableConfig
1414from langchain_core .messages import BaseMessage
15- from jupyterlab_magic_wand .state import AIWorkflowState , ConfigSchema
16- from jupyterlab_magic_wand .agents .lab_commands import (
15+ from jupyterlab_magic_wand .agents .tools import (
1716 update_cell_source ,
1817 show_diff ,
1918 insert_cell_below
2019)
21- from jupyterlab_magic_wand . agents .base import Agent
20+ from .base import Agent
2221
2322from langchain_core .messages import HumanMessage
23+ from pydantic import BaseModel
2424
25- graph = StateGraph (AIWorkflowState , config_schema = ConfigSchema )
25+
class LabCommand(BaseModel):
    """A single command to be executed by the JupyterLab frontend."""

    # Name identifying the command.
    name: str
    # Arguments passed along with the command.
    args: dict
29+
30+
class Context(BaseModel):
    """Notebook context for a single agent invocation."""

    # ID of the cell the user invoked the agent on.
    cell_id: str
    # Full notebook content; expected to contain a "cells" list
    # (see how get_cell reads content["cells"] below).
    content: dict
34+
35+
class State(BaseModel):
    """Shared state threaded through the LangGraph workflow nodes."""

    # Name of the agent handling the request, if any.
    agent: Optional[str] = None
    # The user's input/prompt text.
    input: str
    # Notebook context (cell id plus notebook content).
    context: Context
    # Model responses accumulated by the workflow nodes.
    # NOTE: a mutable default is safe on a pydantic model — pydantic
    # copies field defaults per instance, unlike plain class attributes.
    messages: list = []
    # Frontend commands produced by the workflow nodes, if any.
    commands: Optional[List[LabCommand]] = None
42+
43+
44+ graph = StateGraph (State )
2645
2746
def get_jupyter_ai_model(jupyter_ai_config):
    """Instantiate the language model configured in Jupyter AI.

    Reads the provider class and its constructor parameters off the
    configuration object and builds the provider instance.
    """
    provider_cls = jupyter_ai_config.lm_provider
    provider_kwargs = jupyter_ai_config.lm_provider_params
    return provider_cls(**provider_kwargs)
3150
32- def get_cell (cell_id : str , state : AIWorkflowState ) -> dict :
33- content = state ["context" ]["content" ]
51+
52+ def get_cell (cell_id : str , state : State ) -> dict :
53+ content = state .context .content
3454 cells = content ["cells" ]
3555
3656 for cell in cells :
@@ -60,8 +80,8 @@ def sanitize_code(code: str) -> str:
6080 )
6181
6282
63- async def router (state : AIWorkflowState ) -> Sequence [str ]:
64- cell_id = state [ " context" ][ " cell_id" ]
83+ async def router (state : State ) -> Sequence [str ]:
84+ cell_id = state . context . cell_id
6585 current = get_cell (cell_id , state )
6686 if current .get ("cell_type" ) == "markdown" :
6787 return ["route_markdown" ]
@@ -103,29 +123,30 @@ def _cast_ai_response(response: Union[str, BaseMessage]):
103123 raise Exception ("The response type must be 'str' or 'BaseMessage'." )
104124 return response
105125
106- async def route_markdown (state : AIWorkflowState , config : RunnableConfig ) -> dict :
107- llm = get_jupyter_ai_model (config ["configurable" ]["jupyter_ai_config" ])
108- cell_id = state ["context" ]["cell_id" ]
126+
127+ async def route_markdown (state : State , config : RunnableConfig ) -> dict :
128+ llm = get_jupyter_ai_model (config ["configurable" ]["jai_config_manager" ])
129+ cell_id = state .context .cell_id
109130 current = get_cell (cell_id , state )
110131 # Spell check
111132 if current ["source" ].strip () != "" :
112133 response = _cast_ai_response (await llm .ainvoke (input = f"Does the following input look like a prompt to write code (answer 'code' only) or content to be editted (answer 'content' only)?\n Input: { current ['source' ]} " ))
113134 if "code" in response .lower ():
114135 response = _cast_ai_response (await llm .ainvoke (input = f"Write code based on the prompt. Then, update the code to make it more efficient, add code comments, and respond with only the code and comments.\n Input: { current ['source' ]} " ))
115136 response = sanitize_code (response )
116- messages = state .get ( " messages" , []) or []
137+ messages = state .messages
117138 messages .append (response )
118- commands = state [ " commands" ]
139+ commands = state . commands
119140 new_cell_id = str (uuid .uuid4 ())
120141 commands .extend ([
121142 insert_cell_below (cell_id , source = response , type = "code" , new_cell_id = new_cell_id ),
122143 ])
123144 return {"commands" : commands , "messages" : messages }
124145 prompt = SPELLCHECK_MARKDOWN .format (input = current ["source" ])
125146 response = _cast_ai_response (await llm .ainvoke (input = prompt ))
126- messages = state .get ( " messages" , []) or []
147+ messages = state .messages
127148 messages .append (response )
128- commands = state [ " commands" ]
149+ commands = state . commands
129150 commands .extend ([
130151 update_cell_source (cell_id , source = response ),
131152 {
@@ -135,8 +156,8 @@ async def route_markdown(state: AIWorkflowState, config: RunnableConfig) -> dict
135156 ])
136157 return {"commands" : commands , "messages" : messages }
137158
138- content = state [ " context" ][ " content" ]
139- cell_id = state [ " context" ][ " cell_id" ]
159+ content = state . context . content
160+ cell_id = state . context . cell_id
140161 cells = content ["cells" ]
141162
142163
@@ -151,7 +172,7 @@ async def route_markdown(state: AIWorkflowState, config: RunnableConfig) -> dict
151172 response = _cast_ai_response (await llm .ainvoke (input = prompt ))
152173 messages = state .get ("messages" , []) or []
153174 messages .append (response )
154- commands = state [ " commands" ]
175+ commands = state . commands
155176 commands .extend ([
156177 update_cell_source (cell_id , source = response ),
157178 {
@@ -177,9 +198,9 @@ async def route_markdown(state: AIWorkflowState, config: RunnableConfig) -> dict
177198{exception_value}
178199"""
179200
180- async def route_exception (state : AIWorkflowState , config : RunnableConfig ) -> dict :
181- llm = get_jupyter_ai_model (config ["configurable" ]["jupyter_ai_config " ])
182- cell_id = state [ " context" ][ " cell_id" ]
201+ async def route_exception (state : State , config : RunnableConfig ) -> dict :
202+ llm = get_jupyter_ai_model (config ["configurable" ]["jai_config_manager " ])
203+ cell_id = state . context . cell_id
183204 current = get_cell (cell_id , state )
184205 exception = get_exception (current )
185206 prompt = exception_prompt .format (
@@ -189,9 +210,9 @@ async def route_exception(state: AIWorkflowState, config: RunnableConfig) -> dic
189210 )
190211 response = _cast_ai_response (await llm .ainvoke (input = prompt ))
191212 response = sanitize_code (response )
192- messages = state .get ( " messages" , []) or []
213+ messages = state .messages
193214 messages .append (response )
194- commands = state [ " commands" ]
215+ commands = state . commands
195216 commands .extend ([
196217 update_cell_source (cell_id , source = response ),
197218 show_diff (cell_id , current ["source" ], response ),
@@ -215,7 +236,7 @@ async def route_exception(state: AIWorkflowState, config: RunnableConfig) -> dic
215236"""
216237
217238def prompt_new_cell_using_context (cell_id , state ):
218- content = state [ " context" ][ " content" ]
239+ content = state . context . content
219240 cells = content ["cells" ]
220241
221242 for i , cell in enumerate (cells ):
@@ -236,20 +257,20 @@ def prompt_new_cell_using_context(cell_id, state):
236257 return prompt
237258
238259
239- async def route_code (state : AIWorkflowState , config : RunnableConfig ):
240- llm = get_jupyter_ai_model (config ["configurable" ]["jupyter_ai_config " ])
260+ async def route_code (state : State , config : RunnableConfig ):
261+ llm = get_jupyter_ai_model (config ["configurable" ]["jai_config_manager " ])
241262
242- cell_id = state [ " context" ][ " cell_id" ]
263+ cell_id = state . context . cell_id
243264 current = get_cell (cell_id , state )
244265 source = current ["source" ]
245266 source = source .strip ()
246267 if source :
247268 prompt = IMPROVE_PROMPT .format (code = source )
248269 response = _cast_ai_response (await llm .ainvoke (prompt , stream = False ))
249270 response = sanitize_code (response )
250- messages = state .get ( " messages" , []) or []
271+ messages = state .messages
251272 messages .append (response )
252- commands = state [ " commands" ]
273+ commands = state . commands
253274 commands .extend ([
254275 update_cell_source (cell_id , source = response ),
255276 show_diff (cell_id , current ["source" ], response ),
@@ -259,9 +280,9 @@ async def route_code(state: AIWorkflowState, config: RunnableConfig):
259280 prompt = prompt_new_cell_using_context (cell_id , state )
260281 response = _cast_ai_response (await llm .ainvoke (input = prompt , stream = False ))
261282 response = sanitize_code (response )
262- messages = state .get ( " messages" , []) or []
283+ messages = state .messages
263284 messages .append (response )
264- commands = state [ " commands" ]
285+ commands = state . commands
265286 commands .extend ([
266287 update_cell_source (cell_id , source = response ),
267288 ])
@@ -281,9 +302,11 @@ async def route_code(state: AIWorkflowState, config: RunnableConfig):
281302
# Compile the graph into a runnable workflow.
workflow = graph.compile()


# Module-level Agent wrapper exposing this workflow.
# NOTE(review): how the hosting extension discovers this object is not
# visible in this chunk — confirm against the agent registry.
agent = Agent(
    name="Magic Button Agent",
    description="Magic Button Agent",
    workflow=workflow,
    version="0.0.1",
    state=State
)
0 commit comments