0xf

日記だよ

ParamSpecで任意のパラメータをジェネリクス指定できる(python 3.10〜)

LangChain本読んでて、@chain の作りがスマートだなー! と思った。

langchain/libs/core/langchain_core/runnables at master · langchain-ai/langchain · GitHub 実装はゴツい!

ともあれ、デコレータで関数を返す以外の発想がなかった。こんな感じですよ。

from typing import Callable, Generic, TypeVar, Any, ParamSpec

P = ParamSpec('P')
R = TypeVar('R')

class Invoker(Generic[P,R]):
    f : Callable[P, R]
    def __init__(self, f: Callable[P, R]):
        self.f = f
    
    def invoke(self, *args: P.args, **kwargs: P.kwargs) -> R:
        print("invoke start")
        r =  self.f(*args, **kwargs)
        print("invoke done")
        return r

@Invoker
def f(a: int, b: int) -> int:
    return a + b

print(f.invoke(1, 2))

実行結果はこう

% python main.py 
invoke start
invoke done
3

ちゃんとIDEで型チェックも効いてる。

で、こういう作りにしておくと関数が任意のクラスのインスタンスとして振る舞うようになるので(もちろん __call__ を実装してもよい)、たとえばパイプ演算子に該当するマジックメソッド __or__ を実装しておけば、add(5) | mul(10) みたいな形で関数合成できるDSLぽい設計が可能だ。LangChain ではこれを使って書かせる風情がある。

def hoge():
  ...

  @chain
  def prompt(..):
    ...

  chain = prompt | model | logger
  chain.invoke("こんにちは")

かっこいいですが黒魔術ではありますね。でもかっこいいと思う。