Jean's Blog

一个专注软件测试开发技术的个人博客

0%

LangGraph之节点和可控制性

第一个Langgraph

定义state

定义两种方式:

  1. TypedDict:属于 Python 标准库 typing 模块的一部分,仅提供静态类型检查,运行时不执行验证
  2. Pydantic:第三方库,需要单独安装,提供运行时数据验证和序列化功能

示例代码

1
2
3
4
5
6
7
from langchain_core.messages import AnyMessage
from typing_extensions import TypedDict

# 节点间通讯的消息类型
class State(TypedDict):
messages: list[AnyMessage]
extra_field: int

定义节点

示例代码

1
2
3
4
5
6
7
8
from langchain_core.messages import AIMessage


def node(state: State):
messages = state["messages"]
new_message = AIMessage("你好!我是一个节点")

return {"messages": messages + [new_message], "extra_field": 10}

创建图

  • 包含一个节点
  • 使用state通信

示例代码

1
2
3
4
5
6
from langgraph.graph import StateGraph

graph_builder = StateGraph(State)
graph_builder.add_node(node)
graph_builder.set_entry_point("node") # 设置入口,入口名称为node
graph = graph_builder.compile() # 进行编译

查看节点与图结构(内置的方法)

Mermaid 是一种基于文本的图表和可视化工具,它允许用户通过简单的文本语法来创建复杂的图表和流程图。它特别适合开发者、文档编写者和技术人员在文档、代码库或网页中嵌入可视化内容。

示例代码

1
2
3
from IPython.display import Image, display

display(Image(graph.get_graph().draw_mermaid_png()))

执行结果展示

image-20250917103406179

调用

示例代码

1
2
3
4
from langchain_core.messages import HumanMessage

result = graph.invoke({"messages": [HumanMessage("你好啊,我是花花!")]})
print(result)

执行结果

1
2
3
{'messages': [HumanMessage(content='你好啊,我是花花!', additional_kwargs={}, response_metadata={}),
AIMessage(content='你好!我是一个节点', additional_kwargs={}, response_metadata={})],
'extra_field': 10}

使用 pretty_print 来格式化显示

输出结果更清晰

示例代码

1
2
3
4
5
from langchain_core.messages import HumanMessage

result = graph.invoke({"messages": [HumanMessage("你好啊,我是花花!")]})
for message in result["messages"]:
message.pretty_print()

执行结果

1
2
3
4
5
6
================================ Human Message =================================

你好啊,我是花花!
================================== Ai Message ==================================

你好!我是一个节点

基本控制:串行控制

示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing_extensions import TypedDict
from IPython.display import Image, display
from langgraph.graph import START, END, StateGraph

# 定义节点通信消息类型
class State(TypedDict):
value_1: str
value_2: str

# 定义节点
def step_1(state: State):
return {"value_1": "a"}

def step_2(state: State):
current_value_1 = state["value_1"]
return {"value_1": f"{current_value_1} + b"}

def step_3(state: State):
return {"value_2": 10}


# 创建图
graph_builder = StateGraph(State)
# 设置图中节点
graph_builder.add_node(step_1)
graph_builder.add_node(step_2)
graph_builder.add_node(step_3)
# 设置图中边
graph_builder.add_edge(START, "step_1") # 开始的边
graph_builder.add_edge("step_1", "step_2")
graph_builder.add_edge("step_2", "step_3")
graph_builder.add_edge("step_3", END) # 结束的边

# 图的编译
graph = graph_builder.compile()

# 查看节点与图结构
display(Image(graph.get_graph().draw_mermaid_png()))

# 调用
res = graph.invoke({"value_1": "c"})
print(res)

查看节点与图结构

image-20250917135828786

执行结果:

1
{'value_1': 'a + b', 'value_2': 10}

基本控制:分支控制

示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import operator
from typing import Any, Annotated
from typing_extensions import TypedDict
from IPython.display import Image, display
from langgraph.graph import START, END, StateGraph

# 定义节点通信消息类型
# Annotated允许为类型提供额外的元数据,而不影响类型检查时对类型本身的理解
class State(TypedDict):
aggregate: Annotated[list, operator.add]

# 定义节点
def a(state: State):
print(f"添加'A'到{state['aggregate']}")
return {"aggregate": ["A"]}

def b(state: State):
print(f"添加'B'到{state['aggregate']}")
return {"aggregate": ["B"]}

def c(state: State):
print(f"添加'C'到{state['aggregate']}")
return {"aggregate": ["C"]}

def d(state: State):
print(f"添加'D'到{state['aggregate']}")
return {"aggregate": ["D"]}

# 创建图
graph_builder = StateGraph(State)
# 设置图中节点
graph_builder.add_node(a)
graph_builder.add_node(b)
graph_builder.add_node(c)
graph_builder.add_node(d)
# 设置图中边
graph_builder.add_edge(START, "a") # 开始的边
graph_builder.add_edge("a", "b")
graph_builder.add_edge("a", "c")
graph_builder.add_edge("b", "d")
graph_builder.add_edge("c", "d")
graph_builder.add_edge("d", END) # 结束的边

# 图的编译
graph = graph_builder.compile()

# 查看节点与图结构
display(Image(graph.get_graph().draw_mermaid_png()))

# 调用
res = graph.invoke({"aggregate": []}, {"configurable": {"thread_id": "foo"}})
print(res)

查看节点与图结构

image-20250917140002989

执行结果:

1
2
3
4
5
6
添加'A'到[]
添加'B'到['A']
添加'C'到['A']
添加'D'到['A', 'B', 'C']

{'aggregate': ['A', 'B', 'C', 'D']}

基本控制:条件分支与循环

分支条件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import operator
from typing import Annotated, Literal
from typing_extensions import TypedDict
from IPython.display import Image, display
from langgraph.graph import StateGraph, START, END


# 定义节点通信消息类型
class State(TypedDict):
aggregate: Annotated[list, operator.add]

# 定义节点
def a(state: State):
print(f'Node A sees {state["aggregate"]}')
return {"aggregate": ["A"]}


def b(state: State):
print(f'Node B sees {state["aggregate"]}')
return {"aggregate": ["B"]}

# 创建图
builder = StateGraph(State)
builder.add_node(a)
builder.add_node(b)

# 设置边
def route(state: State) -> Literal["b", END]:
if len(state["aggregate"]) < 7:
return "b"
else:
return END


builder.add_edge(START, "a")
builder.add_conditional_edges("a", route) # 条件边
builder.add_edge("b", "a")

# 图的编译
graph = builder.compile()

# 查看节点与图结构
display(Image(graph.get_graph().draw_mermaid_png()))

# 调用
res = graph.invoke({"aggregate": []})
print(res)

查看节点与图结构

image-20250917140318708

执行结果:

1
2
3
4
5
6
7
8
9
Node A sees []
Node B sees ['A']
Node A sees ['A', 'B']
Node B sees ['A', 'B', 'A']
Node A sees ['A', 'B', 'A', 'B']
Node B sees ['A', 'B', 'A', 'B', 'A']
Node A sees ['A', 'B', 'A', 'B', 'A', 'B']

{'aggregate': ['A', 'B', 'A', 'B', 'A', 'B', 'A']}

注意:使用递归限制recursion_limit,防止异常情况下的大量无用调用

1
2
3
4
5
6
from langgraph.errors import GraphRecursionError

try:
graph.invoke({"aggregate": []}, {"recursion_limit": 4})
except GraphRecursionError:
print("Recursion Error") # 递归错误 超出限制

执行结果

1
2
3
4
5
Node A sees []
Node B sees ['A']
Node A sees ['A', 'B']
Node B sees ['A', 'B', 'A']
Recursion Error

循环

示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import operator
from typing import Annotated, Literal
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from IPython.display import Image, display

# 定义节点通信消息类型
class State(TypedDict):
aggregate: Annotated[list, operator.add]

# 定义节点
def a(state: State):
print(f'Node A sees {state["aggregate"]}')
return {"aggregate": ["A"]}

def b(state: State):
print(f'Node B sees {state["aggregate"]}')
return {"aggregate": ["B"]}

def c(state: State):
print(f'Node C sees {state["aggregate"]}')
return {"aggregate": ["C"]}

def d(state: State):
print(f'Node D sees {state["aggregate"]}')
return {"aggregate": ["D"]}

# 创建图
builder = StateGraph(State)

# 设置节点
builder.add_node(a)
builder.add_node(b)
builder.add_node(c)
builder.add_node(d)

# 设置边
def route(state: State) -> Literal["b", END]:
if len(state["aggregate"]) < 7:
return "b"
else:
return END

builder.add_edge(START, "a")
builder.add_conditional_edges("a", route)
builder.add_edge("b", "c")
builder.add_edge("b", "d")
builder.add_edge(["c", "d"], "a")

# 图的编译
graph = builder.compile()

# 查看节点与图结构
display(Image(graph.get_graph().draw_mermaid_png()))

# 调用
res = graph.invoke({"aggregate": []})
print(res)

查看节点与图结构

image-20250917140631860

执行结果

1
2
3
4
5
6
7
8
9
10
11
Node A sees []
Node B sees ['A']
Node C sees ['A', 'B']
Node D sees ['A', 'B']
Node A sees ['A', 'B', 'C', 'D']
Node B sees ['A', 'B', 'C', 'D', 'A']
Node C sees ['A', 'B', 'C', 'D', 'A', 'B']
Node D sees ['A', 'B', 'C', 'D', 'A', 'B']
Node A sees ['A', 'B', 'C', 'D', 'A', 'B', 'C', 'D']

{'aggregate': ['A', 'B', 'C', 'D', 'A', 'B', 'C', 'D', 'A']}

精细控制:图的运行时配置

示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import operator
from typing import Annotated, Sequence
from typing_extensions import TypedDict

from langchain_deepseek import ChatDeepSeek
from langchain_openai import ChatOpenAI
import os
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.runnables.config import RunnableConfig
from langgraph.graph import END, StateGraph, START

# 定义模型
model = ChatDeepSeek(
model="deepseek-chat",
temperature=0,
api_key=os.environ.get("DEEPSEEK_API_KEY"),
base_url=os.environ.get("DEEPSEEK_API_BASE"),
)

model1 = ChatOpenAI(
temperature=0,
api_key=os.environ.get("OPENAI_API_KEY"),
base_url=os.environ.get("OPENAI_API_BASE"),
)

# 定义要切换的模型
models = {
"deepseek": model,
"openai": model1,
}


class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]


def _call_model(state: AgentState, config: RunnableConfig):
# 使用LCEL的配置
model_name = config["configurable"].get("model", "deepseek") # 设置默认模型
model = models[model_name]
response = model.invoke(state["messages"])
return {"messages": [response]}


# 创建图
builder = StateGraph(AgentState)
# 定义节点
builder.add_node("model", _call_model)
# 定义边
builder.add_edge(START, "model")
builder.add_edge("model", END)

graph = builder.compile()

# 查看节点与图结构
display(Image(graph.get_graph().draw_mermaid_png()))

没有增加运行时配置的情况下,它会默认调用deepseek

1
2
res = graph.invoke({"messages": [HumanMessage(content="hi 你是谁?")]})
print(res)

执行结果为

1
2
{'messages': [HumanMessage(content='hi 你是谁?', additional_kwargs={}, response_metadata={}),
AIMessage(content='嗨!我是DeepSeek-V3,你的智能助手,由深度求索公司创造。😊 我可以帮你解答问题、聊天、提供建议,甚至协助你处理各种学习和工作上的任务。有什么我可以帮你的吗?', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 49, 'prompt_tokens': 8, 'total_tokens': 57, 'completion_tokens_details': None, 'prompt_tokens_details': {'audio_tokens': None, 'cached_tokens': 0}, 'prompt_cache_hit_tokens': 0, 'prompt_cache_miss_tokens': 8}, 'model_name': 'deepseek-chat', 'system_fingerprint': 'fp_08f168e49b_prod0820_fp8_kvcache', 'id': '228c33b4-21f7-4c47-82bd-7417ca36d81d', 'service_tier': None, 'finish_reason': 'stop', 'logprobs': None}, id='run--ff284eca-736c-4b89-bb63-7e5c07ea1786-0', usage_metadata={'input_tokens': 8, 'output_tokens': 49, 'total_tokens': 57, 'input_token_details': {'cache_read': 0}, 'output_token_details': {}})]}

增加运行时配置,动态切换模型

1
2
3
config = {"configurable": {"model": "openai"}}
res = graph.invoke({"messages": [HumanMessage(content="hi 你是谁?")]}, config=config)
print(res)

执行结果为:

1
2
{'messages': [HumanMessage(content='hi 你是谁?', additional_kwargs={}, response_metadata={}),
AIMessage(content='你好,我是一个人工智能助手。我可以回答你的问题和提供帮助。有什么可以帮到你的吗?', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 14, 'total_tokens': 58, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'id': 'chatcmpl-CGcvpmH2o6wgubw0PLBymye0uPyzN', 'service_tier': 'default', 'finish_reason': 'stop', 'logprobs': None}, id='run--5a5fb090-eae6-4a54-bef8-c877446d5a63-0', usage_metadata={'input_tokens': 14, 'output_tokens': 44, 'total_tokens': 58, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}

精细控制:map-reduce并行执行

给定一个来自用户的一般主题,生成相关主题列表,为每个主题生成一个笑话,并从结果列表中选择最佳笑话。

image-20250917141044009

示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import operator
from typing import Annotated
from typing_extensions import TypedDict
from langchain_deepseek import ChatDeepSeek
from langgraph.types import Send
from langgraph.graph import END, StateGraph, START
from IPython.display import Image
from pydantic import BaseModel, Field

# 模型和提示词
# 定义我们将使用的模型和提示词
subjects_prompt = """生成一个逗号分隔的列表,包含2到5个与以下主题相关的例子:{topic}。"""
joke_prompt = """生成一个关于{subject}的笑话"""
best_joke_prompt = """以下是一些关于{topic}的笑话。选出最好的一个!返回最佳笑话的ID。{jokes}"""

# 定义以下三个数据模型
class Subjects(BaseModel):
subjects: list[str]

class Joke(BaseModel):
joke: str

class BestJoke(BaseModel):
id: int = Field(description="最佳笑话的索引,从0开始", ge=0)

# 定义大模型
model = ChatDeepSeek(
model="deepseek-chat",
temperature=0,
api_key=os.environ.get("DEEPSEEK_API_KEY"),
base_url=os.environ.get("DEEPSEEK_API_BASE"),
)

# 图组件:定义构成图的组件


# 这将是主图的整体状态。
# 它将包含一个主题(我们期望用户提供)
# 然后将生成一个主题列表,并为每个主题生成一个笑话
class OverallState(TypedDict):
topic: str
subjects: list
# 注意这里我们使用operator.add
# 这是因为我们想把从各个节点生成的所有笑话
# 合并回一个列表 - 这本质上是"归约"部分
jokes: Annotated[list, operator.add]
best_selected_joke: str


# 这将是我们将"映射"所有主题的节点的状态
# 用于生成笑话
class JokeState(TypedDict):
subject: str

# 定义节点
# 这是我们用来生成笑话主题的函数
def generate_topics(state: OverallState):
prompt = subjects_prompt.format(topic=state["topic"])
# 模型进行结构化输出
response = model.with_structured_output(Subjects).invoke(prompt)
return {"subjects": response.subjects}


# 这里我们根据给定的主题生成笑话
def generate_joke(state: JokeState):
prompt = joke_prompt.format(subject=state["subject"])
# 模型进行结构化输出
response = model.with_structured_output(Joke).invoke(prompt)
return {"jokes": [response.joke]}


# 这里我们定义映射到生成的主题上的逻辑
# 我们将在图中使用这个作为边缘
def continue_to_jokes(state: OverallState):
# 我们将返回一个`Send`对象列表
# 每个`Send`对象包含图中节点的名称
# 以及要发送到该节点的状态
return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]


# 这里我们将评判最佳笑话
def best_joke(state: OverallState):
jokes = "\n\n".join(state["jokes"])
prompt = best_joke_prompt.format(topic=state["topic"], jokes=jokes)
response = model.with_structured_output(BestJoke).invoke(prompt)
return {"best_selected_joke": state["jokes"][response.id]}


# 构建图:这里我们将所有内容组合在一起构建我们的图
graph = StateGraph(OverallState)
# 设置节点
graph.add_node("generate_topics", generate_topics)
graph.add_node("generate_joke", generate_joke)
graph.add_node("best_joke", best_joke)
# 设置边
graph.add_edge(START, "generate_topics")
graph.add_conditional_edges("generate_topics", continue_to_jokes, ["generate_joke"]) # 条件边
graph.add_edge("generate_joke", "best_joke")
graph.add_edge("best_joke", END)
app = graph.compile() # 图编译

# 查看节点与图结构
Image(app.get_graph().draw_mermaid_png())

# 调用
for s in app.stream({"topic": "动物"}):
print(s)

查看节点与图结构

image-20250917141132825

执行结果为

1
2
3
{'generate_topics': {'subjects': ['动物']}}
{'generate_joke': {'jokes': ['为什么老虎不喝茶?因为它们喝茶会变成茶虎!']}}
{'best_joke': {'best_selected_joke': '为什么老虎不喝茶?因为它们喝茶会变成茶虎!'}}