Jean's Blog

一个专注软件测试开发技术的个人博客

0%

LangChain通过类继承实现自定义中间件

介绍

基于类的中间件实现是一种更灵活、功能更强的钩子开发方式,特别适合需要多个钩子、状态管理或复杂配置的场景。它通过继承 AgentMiddleware 基类,在一个类中统一实现多个钩子方法,解决了函数式装饰器难以管理多钩子和共享状态的问题。

类继承方式的工作原理

  1. 继承基类:自定义中间件类需要继承框架提供的 AgentMiddleware 基类,这是实现类式中间件的基础。
  2. 实现钩子方法:在子类中重写 / 实现特定的钩子方法(如 before_modelafter_model),这些方法会在对应生命周期阶段被自动调用。
  3. 多钩子整合:可以在同一个类中同时实现 before_modelafter_model 等多个钩子,让相关逻辑内聚在一个类中。
  4. 状态管理:通过类的实例变量(如 self.counterself.cache)来维护跨请求的状态,实现计数、缓存、会话级别的数据共享。

与函数式装饰器的核心区别 ✨

维度 函数式装饰器(@before_model/@after_model 类继承方式(AgentMiddleware
适用场景 简单、独立的单钩子逻辑 复杂、多钩子、需要状态管理的场景
钩子数量 一个函数对应一个钩子 一个类可实现多个钩子(before/after 等)
状态管理 依赖外部变量或闭包,管理不便 直接使用类实例变量,状态管理更清晰
代码组织 钩子逻辑分散在不同函数中 相关逻辑内聚在一个类中,便于维护
配置能力 装饰器参数简单配置 可通过 __init__ 实现复杂初始化和配置

核心价值总结 💡

  • 高内聚:将输入预处理、输出后处理等相关逻辑封装在同一个类中,代码结构更清晰。
  • 状态可控:通过实例变量轻松实现跨请求的状态共享(如统计调用次数、缓存结果)。
  • 扩展性强:便于继承和复用,可通过子类扩展功能,适合构建复杂的中间件系统。
  • 配置灵活:支持在初始化时传入参数,实现高度可配置的中间件。

AgentMiddleware核心方法详解

AgentMiddleware导包:from langchain.agents.middleware import AgentMiddleware

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
class AgentMiddleware(Generic[StateT, ContextT]):
"""Base middleware class for an agent.

Subclass this and implement any of the defined methods to customize agent behavior
between steps in the main agent loop.
"""

state_schema: type[StateT] = cast("type[StateT]", AgentState)
"""The schema for state passed to the middleware nodes."""

tools: list[BaseTool]
"""Additional tools registered by the middleware."""

@property
def name(self) -> str:
"""The name of the middleware instance.

Defaults to the class name, but can be overridden for custom naming.
"""
return self.__class__.__name__

def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run before the agent execution starts.

Async version is `abefore_agent`
"""

async def abefore_agent(
self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async logic to run before the agent execution starts."""

def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run before the model is called.

Async version is `abefore_model`
"""

async def abefore_model(
self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async logic to run before the model is called."""

def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run after the model is called.

Async version is `aafter_model`
"""

async def aafter_model(
self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async logic to run after the model is called."""

def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Intercept and control model execution via handler callback.

Async version is `awrap_model_call`

The handler callback executes the model request and returns a `ModelResponse`.
Middleware can call the handler multiple times for retry logic, skip calling
it to short-circuit, or modify the request/response. Multiple middleware
compose with first in list as outermost layer.

Args:
request: Model request to execute (includes state and runtime).
handler: Callback that executes the model request and returns
`ModelResponse`.

Call this to execute the model.

Can be called multiple times for retry logic.

Can skip calling it to short-circuit.

Returns:
`ModelCallResult`
"""
msg = (
"Synchronous implementation of wrap_model_call is not available. "
"You are likely encountering this error because you defined only the async version "
"(awrap_model_call) and invoked your agent in a synchronous context "
"(e.g., using `stream()` or `invoke()`). "
"To resolve this, either: "
"(1) subclass AgentMiddleware and implement the synchronous wrap_model_call method, "
"(2) use the @wrap_model_call decorator on a standalone sync function, or "
"(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
)
raise NotImplementedError(msg)

async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Intercept and control async model execution via handler callback.

The handler callback executes the model request and returns a `ModelResponse`.

Middleware can call the handler multiple times for retry logic, skip calling
it to short-circuit, or modify the request/response. Multiple middleware
compose with first in list as outermost layer.

Args:
request: Model request to execute (includes state and runtime).
handler: Async callback that executes the model request and returns
`ModelResponse`.

Call this to execute the model.

Can be called multiple times for retry logic.

Can skip calling it to short-circuit.

Returns:
`ModelCallResult`
"""
msg = (
"Asynchronous implementation of awrap_model_call is not available. "
"You are likely encountering this error because you defined only the sync version "
"(wrap_model_call) and invoked your agent in an asynchronous context "
"(e.g., using `astream()` or `ainvoke()`). "
"To resolve this, either: "
"(1) subclass AgentMiddleware and implement the asynchronous awrap_model_call method, "
"(2) use the @wrap_model_call decorator on a standalone async function, or "
"(3) invoke your agent synchronously using `stream()` or `invoke()`."
)
raise NotImplementedError(msg)

def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run after the agent execution completes."""

async def aafter_agent(
self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async logic to run after the agent execution completes."""

def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept tool execution for retries, monitoring, or modification.

Async version is `awrap_tool_call`

Multiple middleware compose automatically (first defined = outermost).

Exceptions propagate unless `handle_tool_errors` is configured on `ToolNode`.

Args:
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.

Access state via `request.state` and runtime via `request.runtime`.
handler: `Callable` to execute the tool (can be called multiple times).

Returns:
`ToolMessage` or `Command` (the final result).

The handler `Callable` can be invoked multiple times for retry logic.

Each call to handler is independent and stateless.
"""
msg = (
"Synchronous implementation of wrap_tool_call is not available. "
"You are likely encountering this error because you defined only the async version "
"(awrap_tool_call) and invoked your agent in a synchronous context "
"(e.g., using `stream()` or `invoke()`). "
"To resolve this, either: "
"(1) subclass AgentMiddleware and implement the synchronous wrap_tool_call method, "
"(2) use the @wrap_tool_call decorator on a standalone sync function, or "
"(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
)
raise NotImplementedError(msg)

async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Intercept and control async tool execution via handler callback.

The handler callback executes the tool call and returns a `ToolMessage` or
`Command`. Middleware can call the handler multiple times for retry logic, skip
calling it to short-circuit, or modify the request/response. Multiple middleware
compose with first in list as outermost layer.

Args:
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.

Access state via `request.state` and runtime via `request.runtime`.
handler: Async callable to execute the tool and returns `ToolMessage` or
`Command`.

Call this to execute the tool.

Can be called multiple times for retry logic.

Can skip calling it to short-circuit.

Returns:
`ToolMessage` or `Command` (the final result).

The handler `Callable` can be invoked multiple times for retry logic.

Each call to handler is independent and stateless.
"""
msg = (
"Asynchronous implementation of awrap_tool_call is not available. "
"You are likely encountering this error because you defined only the sync version "
"(wrap_tool_call) and invoked your agent in an asynchronous context "
"(e.g., using `astream()` or `ainvoke()`). "
"To resolve this, either: "
"(1) subclass AgentMiddleware and implement the asynchronous awrap_tool_call method, "
"(2) use the @wrap_tool_call decorator on a standalone async function, or "
"(3) invoke your agent synchronously using `stream()` or `invoke()`."
)
raise NotImplementedError(msg)

before_agent 方法

在Agent执行流程开始前调用,用于初始化、验证或修改初始状态。

1
2
3
4
5
6
7
8
9
10
11
12
def before_agent(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
"""
参数:
state: Agent当前状态,包含messages等信息
runtime: Agent运行时环境
返回:
可选的状态更新字典,或None表示不修改状态
"""
messages = state.get('messages', [])
if len(messages) == 0:
return None # 不修改状态
return {"context": {"message_category": "short"}} # 返回状态更新

方法解读

  • 作用:这是类继承方式中 AgentMiddleware 的一个钩子方法,在Agent 执行流程开始前被调用。

  • 参数

    • self:类实例本身,用于访问实例变量和其他方法。

    • state: AgentState:Agent 的当前状态,包含 messages(对话历史)等核心数据。

    • runtime: Any:Agent 运行时环境,提供配置、工具等上下文信息。

  • 返回值

    • dict[str, Any]:状态更新字典,用于修改 Agent 的状态。

    • None:表示不做任何修改,保持原始状态。

核心逻辑

  • 步骤 1:从 state 中获取 messages 列表,若不存在则默认空列表。

  • 步骤 2:判断对话历史是否为空:

    • 若为空(len(messages) == 0),返回 None,不修改状态。

    • 若不为空,返回状态更新字典,在 context 中新增 message_category: "short" 字段,标记当前消息为 “短消息” 类别。

after_agent方法

在Agent执行流程结束后调用,用于分析结果、记录日志或进行后续处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def after_agent(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
"""
参数:
state: Agent当前状态,包含更新后的messages等信息
runtime: Agent运行时环境
返回:
可选的状态更新字典,或None表示不修改状态
"""
messages = state.get('messages', [])
if messages:
last_msg = messages[-1]
if last_msg.__class__.__name__ == 'AIMessage':
return {"analysis": {"quality": "good"}}
return None

方法解读

  • 作用:这是类继承方式中 AgentMiddleware 的钩子方法,在Agent 执行流程结束后被调用,用于对最终状态做收尾处理、分析或记录。

  • 参数

    • self:类实例本身,用于访问实例变量和其他方法。

    • state: AgentState:Agent 的当前状态,包含执行完成后更新的 messages(完整对话历史)等数据。

    • runtime: Any:Agent 运行时环境,提供配置、工具等上下文信息。

  • 返回值

    • dict[str, Any]:状态更新字典,用于修改 Agent 的最终状态。

    • None:表示不做任何修改,保持最终状态。

核心逻辑

  • 步骤 1:从 state 中获取 messages 列表,若不存在则默认空列表。

  • 步骤 2:判断是否存在消息:

    • 若存在消息,取出最后一条消息 last_msg

    • 检查最后一条消息的类名是否为 AIMessage(即模型生成的回复)。

    • 如果是模型回复,则返回状态更新字典,在 analysis 中标记 quality: "good",表示本次响应质量良好。

  • 步骤 3:若不满足条件,返回 None,不修改状态。

before_model方法

在每次用语言模型前调用,用于处理输入、设置参数或进行安全检查。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def before_model(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
"""
参数:
state: 包含要发送给模型的messages等信息
runtime: Agent运行时环境
返回:
可选的状态更新字典,或None表示不修改状态
"""
messages = state.get('messages', [])
if messages:
last_msg = messages[-1]
if hasattr(last_msg, 'content'):
content = last_msg.content.lower()
if '你好' in content:
return {"model_settings": {"mode": "friendly"}}
return None

方法解读

  • 作用:这是类继承方式中 AgentMiddleware 的钩子方法,在调用语言模型之前执行,用于预处理输入、调整模型参数或控制流程。

  • 参数

    • self:类实例本身,用于访问实例变量和其他方法。

    • state: AgentState:包含即将发送给模型的 messages 等状态数据。

    • runtime: Any:Agent 运行时环境,提供配置、工具等上下文信息。

  • 返回值

    • dict[str, Any]:状态更新字典,可修改模型输入或配置。

    • None:表示不修改状态,直接调用模型。

核心逻辑

  • 步骤 1:从 state 中获取 messages 列表,若不存在则默认空列表。

  • 步骤 2:判断是否存在消息:

    • 若存在消息,取出最后一条用户消息 last_msg

    • 检查消息是否有 content 属性,若有则将内容转为小写。

    • 若内容中包含关键词 “你好”,则返回状态更新字典,设置 model_settings.mode = "friendly",告诉模型以友好模式响应。

  • 步骤 3:若不满足条件,返回 None,不修改状态。

after_model方法

在每次模型返回响应调用,用于分析响应、记录日志或进行后处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def after_model(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
"""
参数:
state: 包含模型返回消息的Agent状态
runtime: Agent运行时环境
返回:
可选的状态更新字典,或None表示不修改状态
"""
messages = state.get('messages', [])
if messages:
last_msg = messages[-1]
if last_msg.__class__.__name__ == 'AIMessage':
return {"response_quality": "normal"}
return None

方法解读

  • 作用:这是类继承方式中 AgentMiddleware 的钩子方法,在语言模型生成响应后执行,用于后处理模型输出、记录或分析结果。

  • 参数

    • self:类实例本身,用于访问实例变量和其他方法。

    • state: AgentState:包含模型返回消息的 Agent 状态,已更新最新的模型响应。

    • runtime: Any:Agent 运行时环境,提供配置、工具等上下文信息。

  • 返回值

    • dict[str, Any]:状态更新字典,可修改或补充模型响应数据。

    • None:表示不修改状态,直接将模型响应传递给后续流程。

核心逻辑

  • 步骤 1:从 state 中获取 messages 列表,若不存在则默认空列表。

  • 步骤 2:判断是否存在消息:

    • 若存在消息,取出最后一条消息 last_msg

    • 检查最后一条消息的类名是否为 AIMessage(即模型生成的回复)。

    • 如果是模型生成的响应,则返回状态更新字典,添加 response_quality: "normal" 字段,标记本次响应质量为 “普通”。

  • 步骤 3:若不满足条件,返回 None,不修改状态。

wrap_model_call方法

包装模型调用过程,用于控制、监控或修改模型请求和响应。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
"""
参数:
request: 模型请求对象,包含要发送给模型的内容
handler: 原始模型调用处理器函数
返回:
模型响应结果
"""
print(f"[{self.log_level}] 准备调用模型")
try:
result = handler(request) # 调用原始处理器
print(f"[{self.log_level}] 模型调用成功")
return result
except Exception as e:
print(f"[{self.log_level}] 模型调用异常: {e}")
raise

方法解读

  • 作用:这是类继承中间件中用于包装模型调用的核心方法,在模型请求发出前后执行自定义逻辑,实现对模型调用的完整拦截与监控。

  • 参数

    • self:类实例本身,可访问实例变量(如示例中的 log_level)。

    • request: ModelRequest:封装了要发送给模型的完整请求数据(如消息列表、模型参数等)。

    • handler: Callable[[ModelRequest], ModelResponse]:原始的模型调用处理函数,是被包装的核心逻辑。

  • 返回值ModelResponse,即模型返回的响应结果。

核心逻辑

  • 前置处理:在调用模型前,打印日志(包含实例的日志级别 self.log_level),提示 “准备调用模型”。
  • 核心调用:通过 handler(request) 执行原始的模型请求逻辑,获取响应结果。
  • 后置处理:调用成功后,打印 “模型调用成功” 日志,并将结果返回。
  • 异常处理:捕获调用过程中所有异常,打印包含异常信息的错误日志后,重新抛出异常,保证错误能被上层正常捕获。

wrap_tool_call方法

包装工具调用过程,用于控制、验证或修改工具调用请求和结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage],
) -> ToolMessage:
"""
参数:
request: 工具调用请求对象
handler: 原始工具调用处理器函数
返回:
工具执行结果消息
"""
tool_name = request.tool_call.get('name', '未知工具')
print(f"[{self.log_level}] 调用工具: {tool_name}")
return handler(request) # 调用原始处理器

方法解读

  • 作用:这是类继承中间件中用于包装工具调用的核心方法,在工具执行前后插入自定义逻辑,实现对工具调用的拦截与监控。

  • 参数

    • self:类实例本身,可访问实例变量(如示例中的 log_level)。

    • request: ToolCallRequest:封装了工具调用的完整请求数据(如工具名称、调用参数等)。

    • handler: Callable[[ToolCallRequest], ToolMessage]:原始的工具调用处理函数,是被包装的核心逻辑。

  • 返回值ToolMessage,即工具执行后返回的结果消息。

核心逻辑

  • 前置处理:从 request.tool_call 中提取工具名称(若不存在则默认 “未知工具”),并打印包含日志级别的调用日志,记录即将调用的工具。

  • 核心调用:通过 handler(request) 执行原始的工具调用逻辑,获取工具执行结果。

  • 返回结果:直接将工具执行结果返回,不做额外修改(可在此处扩展后置处理逻辑)。

完整的中间件示例类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# 完整的中间件示例类
from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest, ModelResponse
from langchain.agents.middleware.types import ToolCallRequest
from langchain_core.messages import ToolMessage
from typing import Any, Callable

class CoreMiddlewareDemo(AgentMiddleware):
# 构造函数
def __init__(self, log_level: str = "INFO"):
self.log_level = log_level

# Agent生命周期钩子
def before_agent(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
messages = state.get('messages', [])
if len(messages) == 0:
print(f"[{self.log_level}] 没有初始消息")
return None
return {"context": {"message_count": len(messages)}}

def after_agent(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
print(f"[{self.log_level}] Agent执行完成")
return None

# Model生命周期钩子
def before_model(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
print(f"[{self.log_level}] 准备调用模型")
return None

def after_model(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
print(f"[{self.log_level}] 模型返回响应")
return None

# 调用包装钩子
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
print(f"[{self.log_level}] 包装模型调用")
return handler(request)

代码结构解读 📝

  1. 类定义CoreMiddlewareDemo 继承自 AgentMiddleware,是一个完整的中间件类,实现了多种生命周期钩子和包装器。
  2. 构造函数 __init__:初始化日志级别 log_level,默认值为 "INFO",用于统一控制日志输出。
  3. Agent 生命周期钩子
    • before_agent:在 Agent 执行前检查消息数量,记录日志并更新上下文。
    • after_agent:在 Agent 执行完成后打印日志,不修改状态。
  4. Model 生命周期钩子
    • before_model:在模型调用前打印日志。
    • after_model:在模型返回响应后打印日志。
  5. 调用包装钩子 wrap_model_call:包装模型调用,打印日志后执行原始调用逻辑。

最佳实践

选择合适的钩子类型 🎯

钩子类型 适用场景
Node-style 钩子before_agent/after_agent/before_model/after_model 需要在特定执行点执行简单逻辑,不需要控制执行流程,仅做状态修改或日志记录
Wrap-style 钩子wrap_model_call/wrap_tool_call 需要控制执行流程,例如实现重试、缓存、修改请求 / 响应对象等底层拦截逻辑
before_agent/after_agent 适合Agent 生命周期的初始化、清理、会话级别的全局操作
before_model/after_model 适合与模型交互相关的逻辑,如输入过滤、输出格式化、模型调用监控

选择合适的实现方式 🛠️

实现方式 适用场景
装饰器方式 简单中间件,仅需单个钩子,无需复杂配置或状态管理,代码轻量直观
类继承方式 复杂中间件,需要多个钩子、状态管理(如计数器、缓存)或自定义配置(如日志级别)
中间件工厂函数 需要根据动态参数创建不同行为的中间件,例如根据环境变量生成不同日志级别的中间件

常见注意事项

这是开发中间件时必须遵守的核心原则:

  • 避免阻塞操作:中间件中的长时间运行任务(如网络请求、大文件读写)会阻塞 Agent 流程,严重影响整体性能,应尽量异步化或轻量化。
  • 错误处理:必须适当捕获和处理异常,避免中间件内部错误导致整个代理崩溃,同时要保证异常信息能被上层感知。
  • 返回值处理:Node-style 钩子中,None 表示不修改状态、继续执行;返回字典则会更新状态,需谨慎设计返回逻辑。
  • 状态修改:修改 AgentState 时要小心,避免意外覆盖重要信息(如对话历史、上下文数据),推荐只新增或修改特定字段。
  • 中间件顺序:多个中间件的执行顺序非常关键,尤其是存在依赖关系时(例如日志中间件应在业务中间件之前执行)。

核心总结 ✨

  1. 先选钩子:根据要干预的生命周期阶段和是否需要控制流程,选择 Node-style 或 Wrap-style 钩子。
  2. 再选实现:根据复杂度选择装饰器、类继承或工厂函数方式。
  3. 最后避坑:遵守性能、错误、状态、顺序等注意事项,保证中间件稳定可靠。