state monad 推导
2025-06-12
函数式语言里面不带副作用,实现状态修改需要一些技巧性的写法,也就是 state monad。
下用面一个例子来推导 state monad。先来一个假设的需求,我们要将一个 sexp 里面出现的所有数字 +1,同时,还需统计这个 sexp 里面出现的符号 a 的次数。
在 cora 语言里面是没有局部变量修改操作,它是一门函数式的语言,虽然不像 haskell 那么纯函数式。 它的符号是可以重绑定的,以及 vector 是可以修改的。所以有绕过去的写法,通过符号的重赋值:
(set '*var* (gensym))
(func parse
'a => (begin
(set *var* (+ (value *var*) 1))
'a)
x => (+ x 1) where (number? x)
[x . xs] => [(parse x) . (parse xs)]
x => x)
(set *var* 0)
(parse `(a 3 (4 2 5 (a) b (d (c)) 7 2)))
(value *var*)
这里的 var 是一个符号,通过 (value var) 和 (set var xxx) 来实现类似全局变量的效果,通过它统计 a 的次数。 这种写法的代码很糟糕,它是带副作用的。
怎么样写不带副作用的代码呢?我们需要给 parse 加一个参数,它既是一个传入参数,也是一个返回值参数:
(func parse
'a s => [(+ s 1) 'a]
x s => [s (+ x 1)] where (number? x)
[x . xs] s => (match (parse x s)
[s1 x1]
(match (parse xs s1)
[s2 xs1]
[s2 [x1 . xs1]]))
x s => [s x])
使用的时候每次都需要传入参数 s,这样就没有副作用了
(let s 0
(parse `(a 3 (4 2 5 (a) b (d (c)) 7 2)) s))
s 其实是一个状态,调用 parse 函数其实是需要修改这个状态。parse 是一个"有副作用" 的方法,它需要返回 parse 后的结果,并且同时需要修改状态 s。实现方式是传入旧的状态,返回新的状态这种形式的 "修改",所以它又是无副作用的。
对于每个有状态的函数,我们都要额外传状态参数,这其实会让人很恼火。怎么样把这个状态参数消除掉呢?
我们可以先进行 curry 变换,把 s 从参数挪出去:
(func parse
'a => (lambda (s) [(+ s 1) 'a])
x => (lambda (s) [s (+ x 1)]) where (number? x)
[x . xs] => (lambda (s)
(match (parse x s)
[s1 x1]
(match (parse xs s1)
[s2 xs1]
[s2 [x1 . xs1]])))
x => (lambda (s) [s x]))
((parse '(a 3 (4 2 5 (a) b (d (c)) 7 2))) 0)
接下来,我们重点为看 (parse x s) 的返回处理,
(match (parse x s)
[s1 x1]
...body)
这是一个多值返回,在多值返回那篇博客中,我们记得返回多个值 s1 x1 等价于接受一个 k 然后调用 (k s1 x1)。
(match (parse x s)
[s1 x1]
(k s1 x1))
k = (lambda (x1)
(lambda (s1)
... ;; 接受多值后的处理,body))
于是前面的 parse 函数我们就可以改写成:
(func parse
'a => (lambda (s) [(+ s 1) 'a])
x => (lambda (s) [s (+ x 1)]) where (number? x)
[x . xs] => (lambda (s)
(match (parse x s)
[s1 x1]
(k x1 s1)))
x => (lambda (s) [s x]))
k = (lambda (x1)
(lambda (s1)
(match (parse xs s1)
[s2 xs1]
[s2 [x1 . xs1]])))
我们当然可以很自由地修改 (parse x s) 这种变成 ((parse x) s),(k x1 s1) 也可以 ((k x1) s1),或者如果我们不喜欢 k 这个名字,我们也可以改名叫 f,都是 curry 变换和 alpha 变换。于是 parse 可以变成这样子:
(func parse
'a => (lambda (s) [(+ s 1) 'a])
x => (lambda (s) [s (+ x 1)]) where (number? x)
[x . xs] => (lambda (s)
(match ((parse x) s)
[s1 x1]
((f x1) s1)))
x => (lambda (s) [s x]))
f = (lambda (x1)
(lambda (s1)
(match (parse xs s1)
[s2 xs1]
[s2 [x1 . xs1]])))
我们把其中 [x . xs] 处理分支拧出来看:
(lambda (s)
(match ((parse x) s)
[s1 x1]
((f x1) s1)))
把 (parse x) 改成参数 m
(lambda (s)
(match (m s)
[s1 x1]
((f x1) s1)))
m = (parse x)
于是我们重新发明了 state monad 的 bind 函数:
(defun bind (m f)
(lambda (s)
(match (m s)
[s1 v]
((f v) s1))))
我们可以用 bind 改写 parse 函数:
(func parse
'a => (lambda (s) [(+ s 1) 'a])
x => (lambda (s) [s (+ x 1)]) where (number? x)
[x . xs] => (bind (parse x) f)
x => (lambda (s) [s x]))
f = (lambda (x1)
(lambda (s1)
(match (parse xs s1)
[s2 xs1]
[s2 [x1 . xs1]])))
其实 f 也是可以用 bind 继续改写的,注意 f 的 body 部分,变换技巧也是跟前面一样,把 match 之后的多值返回,返回多个值等价于传递一个连续参数 k1:
(lambda (s1)
(match ((parse xs) s1)
[s2 xs1]
(k1 xs1 s2)))
k1 = (lambda (xs1)
(lambda (s2)
[s2 [x1 . xs1]]))
于是变成了 (bind f k1)
也就是说前面的 parse 可以最终写成这个样子:
(func parse
'a => (lambda (s) [(+ s 1) 'a])
x => (lambda (s) [s (+ x 1)]) where (number? x)
[x . xs] => (bind (parse x)
(lambda (x1)
(bind (parse xs)
(lambda (xs1)
(lambda (s2)
[s2 [x1 . xs1]])))))
x => (lambda (s) [s x]))
我们定义 state monad 的 return:
(defun return (v)
(lambda (s)
[s v]))
于是 parse 可以继续用 return 改写:
(func parse
'a => (lambda (s) [(+ s 1) 'a])
x => (return (+ x 1) where (number? x)
[x . xs] => (bind (parse x)
(lambda (x1)
(bind (parse xs)
(lambda (xs1)
(return [x1 . xs1]))))))
x => (return x))
大功告成!这已经是改成用 monad 写法了。再看一看 state monad 是什么:
(defun bind (m f)
(lambda (s)
(match (m s)
[s1 v]
((f v) s1))))
(defun return (v)
(lambda (s)
[s v]))
bind 会接受一个 state monad,以及处理从 monad 提取到的值的函数 f,它会返回一个新的 state monad。 而 state monad 是什么呢?是一个接受一个状态,返回新的状态和新的值的 closure。从 return 的定义来看这一点理解得更清晰。
如果用闭包和对象等价的角度来看("对象是穷人的闭包"),state monad 是一个对象,"这个对象里面有一个 field 存了状态,当在这个对象上面执行操作的时候,不是直接修改这个对象的 field,而是返回一个新的对象,新的对象中的状态被更新了"。是这样的么?不是!状态是需要传进去的,而不是存储在对象的 field 里面的。需要给这个对象传一个 state,就可以"激活"这个对象了。其实值才是对象中的 field,而状态不是。
(m s) 让 monad 接受一个状态,就会返回状态和值。用 match (m s) 去处理返回的状态和值,值传递给 f 去生成新的 monad,新的 monad (f v) 再接受老的状态 s1 更新到新的状态和新的值。好绕!
2025-06-12 更新
发现一个更简单的推导
最初的 (set 'var (gensym)) 的方式是不纯的,因此需要改写 parse 成一个纯函数。
parse 首先需要补一个参数 s,这个输入的状态,也是将要返回的状态。 由于要同时返回新状态,以及 parse 完的结果,这是一个多值返回,用 cps 的技巧来搞多值返回,需要再被一个参数 k
于是 parse 的写法:
(func parse
'a s k => (k (+ s 1) 'a)
x s k => (k s (+ x 1)) where (number? x)
[x . xs] s k => (parse x s
(lambda (s1 x1)
(parse xs s1
(lambda (s2 xs1)
(k s2 [x1 . xs1]])))))
x s k => (k s x))
把 s 和 k 两个参数都 curry 化:
(func parse
'a => (return 'a)
x => (lambda (s k) (k s (+ x 1))) where (number? x)
[x . xs] => (lambda (s k)
(parse x s
(lambda (s1 x1)
(parse xs s1
(lambda (s2 xs1)
(k s2 [x1 . xs1]]))))))
x => (return x))
定义 return
(defun return (x)
(lambda (s k)
(k s x)))
用 return 改写 parse:
(func parse
'a => (return 'a)
x => (lambda (s k) (k s (+ x 1))) where (number? x)
[x . xs] => (lambda (s k)
(parse x s
(lambda (s1 x1)
(parse xs s1
(lambda (s2 xs1)
(k s2 [x1 . xs1]]))))))
x => (return x))
好像卡住了,怎么继续推导出:
(defun bind (m f)
(lambda (s)
((m s)
(lambda (s1 v)
(f v s1)))))