Skip to content

Commit 7fa74ea

Browse files
authored
Merge pull request #12 from Zsailer/agent-api
Simplify server extension to make agent development lightweight
2 parents bf362f9 + 973f779 commit 7fa74ea

File tree

17 files changed

+172
-287
lines changed

17 files changed

+172
-287
lines changed
Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
1-
from dataclasses import dataclass
2-
from langgraph.graph.graph import CompiledGraph
1+
import urllib.parse
2+
from typing import Type
3+
from pydantic import BaseModel, ConfigDict
4+
from langgraph.graph.state import CompiledStateGraph
35

46

5-
@dataclass
6-
class Agent:
7+
class Agent(BaseModel):
8+
model_config = ConfigDict(arbitrary_types_allowed=True)
79
name: str
810
description: str
9-
workflow: CompiledGraph
10-
version: str
11+
workflow: CompiledStateGraph
12+
version: str
13+
state: Type[BaseModel]
14+
15+
@property
16+
def state_schema_id(self):
17+
name = urllib.parse.quote_plus(self.name)
18+
return f"https://events.jupyter.org/jupyter-ai/agents/{name}/state"
19+
20+
@property
21+
def state_schema(self):
22+
event_schema = {
23+
"$id": self.state_schema_id,
24+
"version": self.version,
25+
"title": "",
26+
"description": "",
27+
"personal-data": True,
28+
"type": "object",
29+
}
30+
event_schema.update(self.state.model_json_schema())
31+
return event_schema

jupyterlab_magic_wand/agents/magic_agent.py

Lines changed: 56 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,50 @@
77
"""
88
import json
99
import uuid
10-
from typing import Sequence, Union
10+
from typing import Sequence, Union, List, Optional
1111
from langgraph.graph import StateGraph
1212
from langgraph.graph import END, START
1313
from langchain_core.runnables import RunnableConfig
1414
from langchain_core.messages import BaseMessage
15-
from jupyterlab_magic_wand.state import AIWorkflowState, ConfigSchema
16-
from jupyterlab_magic_wand.agents.lab_commands import (
15+
from jupyterlab_magic_wand.agents.tools import (
1716
update_cell_source,
1817
show_diff,
1918
insert_cell_below
2019
)
21-
from jupyterlab_magic_wand.agents.base import Agent
20+
from .base import Agent
2221

2322
from langchain_core.messages import HumanMessage
23+
from pydantic import BaseModel
2424

25-
graph = StateGraph(AIWorkflowState, config_schema=ConfigSchema)
25+
26+
class LabCommand(BaseModel):
27+
name: str
28+
args: dict
29+
30+
31+
class Context(BaseModel):
32+
cell_id: str
33+
content: dict
34+
35+
36+
class State(BaseModel):
37+
agent: Optional[str] = None
38+
input: str
39+
context: Context
40+
messages: list = []
41+
commands: Optional[List[LabCommand]] = None
42+
43+
44+
graph = StateGraph(State)
2645

2746

2847
def get_jupyter_ai_model(jupyter_ai_config):
2948
lm_provider = jupyter_ai_config.lm_provider
3049
return lm_provider(**jupyter_ai_config.lm_provider_params)
3150

32-
def get_cell(cell_id: str, state: AIWorkflowState) -> dict:
33-
content = state["context"]["content"]
51+
52+
def get_cell(cell_id: str, state: State) -> dict:
53+
content = state.context.content
3454
cells = content["cells"]
3555

3656
for cell in cells:
@@ -60,8 +80,8 @@ def sanitize_code(code: str) -> str:
6080
)
6181

6282

63-
async def router(state: AIWorkflowState) -> Sequence[str]:
64-
cell_id = state["context"]["cell_id"]
83+
async def router(state: State) -> Sequence[str]:
84+
cell_id = state.context.cell_id
6585
current = get_cell(cell_id, state)
6686
if current.get("cell_type") == "markdown":
6787
return ["route_markdown"]
@@ -103,29 +123,30 @@ def _cast_ai_response(response: Union[str, BaseMessage]):
103123
raise Exception("The response type must be 'str' or 'BaseMessage'.")
104124
return response
105125

106-
async def route_markdown(state: AIWorkflowState, config: RunnableConfig) -> dict:
107-
llm = get_jupyter_ai_model(config["configurable"]["jupyter_ai_config"])
108-
cell_id = state["context"]["cell_id"]
126+
127+
async def route_markdown(state: State, config: RunnableConfig) -> dict:
128+
llm = get_jupyter_ai_model(config["configurable"]["jai_config_manager"])
129+
cell_id = state.context.cell_id
109130
current = get_cell(cell_id, state)
110131
# Spell check
111132
if current["source"].strip() != "":
112133
response = _cast_ai_response(await llm.ainvoke(input=f"Does the following input look like a prompt to write code (answer 'code' only) or content to be editted (answer 'content' only)?\n Input: {current['source']}"))
113134
if "code" in response.lower():
114135
response = _cast_ai_response(await llm.ainvoke(input=f"Write code based on the prompt. Then, update the code to make it more efficient, add code comments, and respond with only the code and comments.\n Input: {current['source']}"))
115136
response = sanitize_code(response)
116-
messages = state.get("messages", []) or []
137+
messages = state.messages
117138
messages.append(response)
118-
commands = state["commands"]
139+
commands = state.commands
119140
new_cell_id = str(uuid.uuid4())
120141
commands.extend([
121142
insert_cell_below(cell_id, source=response, type="code", new_cell_id=new_cell_id),
122143
])
123144
return {"commands": commands, "messages": messages}
124145
prompt = SPELLCHECK_MARKDOWN.format(input=current["source"])
125146
response = _cast_ai_response(await llm.ainvoke(input=prompt))
126-
messages = state.get("messages", []) or []
147+
messages = state.messages
127148
messages.append(response)
128-
commands = state["commands"]
149+
commands = state.commands
129150
commands.extend([
130151
update_cell_source(cell_id, source=response),
131152
{
@@ -135,8 +156,8 @@ async def route_markdown(state: AIWorkflowState, config: RunnableConfig) -> dict
135156
])
136157
return {"commands": commands, "messages": messages}
137158

138-
content = state["context"]["content"]
139-
cell_id = state["context"]["cell_id"]
159+
content = state.context.content
160+
cell_id = state.context.cell_id
140161
cells = content["cells"]
141162

142163

@@ -151,7 +172,7 @@ async def route_markdown(state: AIWorkflowState, config: RunnableConfig) -> dict
151172
response = _cast_ai_response(await llm.ainvoke(input=prompt))
152173
messages = state.get("messages", []) or []
153174
messages.append(response)
154-
commands = state["commands"]
175+
commands = state.commands
155176
commands.extend([
156177
update_cell_source(cell_id, source=response),
157178
{
@@ -177,9 +198,9 @@ async def route_markdown(state: AIWorkflowState, config: RunnableConfig) -> dict
177198
{exception_value}
178199
"""
179200

180-
async def route_exception(state: AIWorkflowState, config: RunnableConfig) -> dict:
181-
llm = get_jupyter_ai_model(config["configurable"]["jupyter_ai_config"])
182-
cell_id = state["context"]["cell_id"]
201+
async def route_exception(state: State, config: RunnableConfig) -> dict:
202+
llm = get_jupyter_ai_model(config["configurable"]["jai_config_manager"])
203+
cell_id = state.context.cell_id
183204
current = get_cell(cell_id, state)
184205
exception = get_exception(current)
185206
prompt = exception_prompt.format(
@@ -189,9 +210,9 @@ async def route_exception(state: AIWorkflowState, config: RunnableConfig) -> dic
189210
)
190211
response = _cast_ai_response(await llm.ainvoke(input=prompt))
191212
response = sanitize_code(response)
192-
messages = state.get("messages", []) or []
213+
messages = state.messages
193214
messages.append(response)
194-
commands = state["commands"]
215+
commands = state.commands
195216
commands.extend([
196217
update_cell_source(cell_id, source=response),
197218
show_diff(cell_id, current["source"], response),
@@ -215,7 +236,7 @@ async def route_exception(state: AIWorkflowState, config: RunnableConfig) -> dic
215236
"""
216237

217238
def prompt_new_cell_using_context(cell_id, state):
218-
content = state["context"]["content"]
239+
content = state.context.content
219240
cells = content["cells"]
220241

221242
for i, cell in enumerate(cells):
@@ -236,20 +257,20 @@ def prompt_new_cell_using_context(cell_id, state):
236257
return prompt
237258

238259

239-
async def route_code(state: AIWorkflowState, config: RunnableConfig):
240-
llm = get_jupyter_ai_model(config["configurable"]["jupyter_ai_config"])
260+
async def route_code(state: State, config: RunnableConfig):
261+
llm = get_jupyter_ai_model(config["configurable"]["jai_config_manager"])
241262

242-
cell_id = state["context"]["cell_id"]
263+
cell_id = state.context.cell_id
243264
current = get_cell(cell_id, state)
244265
source = current["source"]
245266
source = source.strip()
246267
if source:
247268
prompt = IMPROVE_PROMPT.format(code=source)
248269
response = _cast_ai_response(await llm.ainvoke(prompt, stream=False))
249270
response = sanitize_code(response)
250-
messages = state.get("messages", []) or []
271+
messages = state.messages
251272
messages.append(response)
252-
commands = state["commands"]
273+
commands = state.commands
253274
commands.extend([
254275
update_cell_source(cell_id, source=response),
255276
show_diff(cell_id, current["source"], response),
@@ -259,9 +280,9 @@ async def route_code(state: AIWorkflowState, config: RunnableConfig):
259280
prompt = prompt_new_cell_using_context(cell_id, state)
260281
response = _cast_ai_response(await llm.ainvoke(input=prompt, stream=False))
261282
response = sanitize_code(response)
262-
messages = state.get("messages", []) or []
283+
messages = state.messages
263284
messages.append(response)
264-
commands = state["commands"]
285+
commands = state.commands
265286
commands.extend([
266287
update_cell_source(cell_id, source=response),
267288
])
@@ -281,9 +302,11 @@ async def route_code(state: AIWorkflowState, config: RunnableConfig):
281302

282303
workflow = graph.compile()
283304

305+
284306
agent = Agent(
285307
name = "Magic Button Agent",
286308
description = "Magic Button Agent",
287309
workflow = workflow,
288-
version = "0.0.1"
310+
version = "0.0.1",
311+
state=State
289312
)
File renamed without changes.

jupyterlab_magic_wand/agents/lab_commands/insert_cell_below.py renamed to jupyterlab_magic_wand/agents/tools/insert_cell_below.py

File renamed without changes.
File renamed without changes.

jupyterlab_magic_wand/agents/lab_commands/update_cell_source.py renamed to jupyterlab_magic_wand/agents/tools/update_cell_source.py

File renamed without changes.

jupyterlab_magic_wand/config.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

jupyterlab_magic_wand/events/magic-button.yml

Lines changed: 0 additions & 35 deletions
This file was deleted.

jupyterlab_magic_wand/extension.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,40 @@
22
from traitlets import Instance, Dict, Unicode, default
33
from jupyter_server.extension.application import ExtensionApp
44

5-
from .magic_handler import MagicHandler
65
from .rest_handlers import handlers
76
from importlib_metadata import entry_points
8-
from .config import ConfigManager
7+
from .agents.base import Agent
98

109

11-
feedback_logger = logging.getLogger("jupyterlab_magic_wand_feedback")
12-
1310
class AIMagicExtension(ExtensionApp):
1411
name = "jupyterlab_magic_wand"
1512
handlers = handlers
1613

17-
magic_handler = Instance(MagicHandler, allow_none=True)
18-
ai_config = Instance(ConfigManager, allow_none=True)
1914
agents = Dict(key_trait=Unicode, value_trait=Instance(object))
20-
feedback = Instance(logging.Logger, allow_none=True)
21-
22-
@default('feedback')
23-
def _default_feedback(self):
24-
return feedback_logger
15+
default_agent = Unicode(
16+
default_value="Magic Button Agent",
17+
help=(
18+
"The name of the default agent, if an agent is not "
19+
"explicitly named when a request is made to the server."
20+
)
21+
).tag(config=True)
2522

2623
def initialize_settings(self):
2724
eps = entry_points()
2825
agents_eps = eps.select(group="jupyterlab_magic_wand.agents")
2926
for eps in agents_eps:
3027
try:
31-
agent = eps.load()
28+
agent: Agent = eps.load()
3229
self.agents[agent.name] = agent
33-
self.log.info(f"Successfully loaded workflow: {agent.name}")
30+
import json
31+
print(json.dumps(agent.state_schema, indent=2))
32+
self.serverapp.event_logger.register_event_schema(agent.state_schema)
33+
self.log.info(f"Successfully loaded agent: {agent.name}")
3434
except Exception as err:
3535
self.log.error(err)
3636
self.log.error(f"Unable to load {agent.name}")
37-
38-
self.ai_config = ConfigManager(self.agents)
39-
self.magic_handler = MagicHandler(
40-
event_logger=self.serverapp.event_logger,
41-
config=self.ai_config,
42-
jupyter_ai_config=self.settings["jai_config_manager"]
43-
)
37+
4438
self.settings.update({
45-
"magic_handler": self.magic_handler,
4639
"agents": self.agents,
47-
"ai_config": self.ai_config,
48-
"feedback": self.feedback
40+
"current_agent": self.default_agent
4941
})

0 commit comments

Comments
 (0)