跳过正文

PocketFlow核心框架

··303 字·2 分钟
loveRyujin
作者
loveRyujin
一名普普通通的go后端工程师
目录

最近在github上发现一个有趣的项目PocketFlow,一个总代码量在100行左右的 LLM 框架,使用 python 编写,零依赖、简洁的实现。阅读了一下项目代码,非常适合学习,故以此篇文章总结分析一下框架源码。

PocketFlow 的核心基础组件主要有:BaseNode、Node、Flow,然后在此基础之上衍生出BatchNode、BatchFlow、AsyncNode等等。

我理解 PocketFlow 的核心就是两件事:

  • 图编排:将各个 Node 通过边组合起来,边即为各类 action。
  • 共享数据:Node 之间通过 SharedData 共享数据,是一个公用的存储。

BaseNode
#

BaseNode 是最基础的父类。定义如下:

class BaseNode:
    def __init__(self): self.params,self.successors={},{}
    def set_params(self,params): self.params=params
    def next(self,node,action="default"):
        if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'")
        self.successors[action]=node; return node
    def prep(self,shared): pass
    def exec(self,prep_res): pass
    def post(self,shared,prep_res,exec_res): pass
    def _exec(self,prep_res): return self.exec(prep_res)
    def _run(self,shared): p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e)
    def run(self,shared): 
        if self.successors: warnings.warn("Node won't run successors. Use Flow.")  
        return self._run(shared)
    def __rshift__(self,other): return self.next(other)
    def __sub__(self,action):
        if isinstance(action,str): return _ConditionalTransition(self,action)
        raise TypeError("Action must be a string")

class _ConditionalTransition:
    def __init__(self,src,action): self.src,self.action=src,action
    def __rshift__(self,tgt): return self.src.next(tgt,self.action)

BaseNode 定义节点的生命周期规范、执行能力和连接不同节点的能力。

生命周期三阶段:

  • prep:从共享存储 shared 中获取特定数据,进行预处理。比如查询数据库、读取文件或者将数据序列化为字符串。
  • exec:拿到 prep 阶段的结果,执行计算逻辑。比如可以执行 LLM 调用、工具调用等。
  • post:结合 prep 和 exec 阶段的结果,进行后置处理,并将特定数据写回共享存储 shared 中。比如更新数据库、记录结果等等。 三个阶段默认实现都是 pass,需要子类去重写实现。

执行入口:

  • run:运行节点。
  • _run:默认顺序调用 prep、_exec、post 三阶段。
  • _exec:默认调用 exec 阶段。

连接节点能力:

  • __rshift__:重载 » 运算符,表示关联另一个 node 作为下一个节点,通过 next 方法,默认action 为 default。比如 a » b,即a.rshift(b)。
  • __sub__:重载 - 运算符,表示节点之间通过特定 action 关联起来。比如 a - “action” » b,会先构造 _ConditionalTransition,也就是条件转换,然后调用它的 rshirt 方法,利用 action 将 a、b 关联起来。

Node
#

Node 是 BaseNode 的进一步封装,在 BaseNode 的基础上增加了重试和兜底机制。

class Node(BaseNode):
    def __init__(self,max_retries=1,wait=0): super().__init__(); self.max_retries,self.wait=max_retries,wait
    def exec_fallback(self,prep_res,exc): raise exc
    def _exec(self,prep_res):
        for self.cur_retry in range(self.max_retries):
            try: return self.exec(prep_res)
            except Exception as e:
                if self.cur_retry==self.max_retries-1: return self.exec_fallback(prep_res,e)
                if self.wait>0: time.sleep(self.wait)
  • max_retries:最大重试次数。
  • wait:等待的时间间隔。
  • exec_fallback:最大重试次数过后走到的兜底逻辑。默认情况下只是重新抛出异常。

Flow
#

Flow 编排节点为图,最简单的可以构造一个串行序列,然后在此基础之上可以添加分支、引入循环。

class Flow(BaseNode):
    def __init__(self,start=None): super().__init__(); self.start_node=start
    def start(self,start): self.start_node=start; return start
    def get_next_node(self,curr,action):
        nxt=curr.successors.get(action or "default")
        if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
        return nxt
    def _orch(self,shared,params=None):
        curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
        while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
        return last_action
    def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o)
    def post(self,shared,prep_res,exec_res): return exec_res
  • _run:Flow 核心逻辑不是计算而是编排,所以运行时阶段主要包括 prep、_orch、post。
  • get_next_node: 根据当前节点的 post 阶段返回的 action 跳转到下一个节点。
  • _orch:具体的编排方法,负责循环执行整个图。

注意的是,Flow 继承 BaseNode,本身也可以抽象视作一个节点。

BatchNode
#

BatchNode 继承自 Node,重写其 _exec 方法。

class BatchNode(Node):
    def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])]
  • _exec:重写父类 Node 的 _exec 方法,接受 items ,对 items 的每一个元素执行 Node 的 _exec方法,即重试+兜底逻辑。

参考
#

  1. PocketFlow官方文档