1:- module(autodiff2, [max/3, mul/3, add/3, pow/3, exp/2, llog/2, log/2, lse/2, deriv/3, back/1, grad/1,
2 esc/3, expand_wsums/0, wsum/2, add_to_wsum/3, gather_ops/3]).
9:- use_module(library(chr)). 10:- use_module(library(rbutils)). 11:- use_module(library(listutils), [measure/2]). 12:- use_module(library(dcg_pair)). 13:- use_module(library(dcg_macros)). 14
% CHR constraint declarations (same constraints and modes as before, grouped).
% In SWI-Prolog CHR mode declarations, '+' marks an argument that is always
% ground when posted, '-' one that is always unbound, and '?' anything.
% NOTE(review): go/0 posts acc(DX,0.0) with a bound second argument, although
% acc/2 declares that argument '-' -- confirm this declaration is intended.

% weighted sums and op gathering
:- chr_constraint expand_wsums, wsum(?,-), add_to_wsum(?,?,-), ops(-,+).
% scalar operations
:- chr_constraint max(?,?,-), add(?,?,-), mul(?,?,-).
:- chr_constraint llog(-,-), log(-,-), exp(-,-), pow(+,-,-).
% vector operations and their shared sub-computations
:- chr_constraint lse(?,-), stoch_exp(?,-), stoch_exp(?,+,-), mes(?,-,-,-), chi(?,?,?,-).
% differentiation machinery
:- chr_constraint deriv(?,-,?), agg(?,-), acc(?,-), acc(-), go, esc(+,?,-).
19
% add_to_wsum(X, Sum0, Sum): Sum is the weighted sum Sum0 with one more X.
% A wsum is an rbtree mapping each term to its integer multiplicity; adding
% to the constant 0.0 starts a fresh table with X counted once.
add_to_wsum(X, 0.0, Sum) <=> ord_list_to_rbtree([X-1], Counts), wsum(Counts, Sum).
add_to_wsum(X, Sum0, Sum), wsum(Counts0, Sum0) <=> incr_term(X, Counts0, Counts), wsum(Counts, Sum).
22
% incr_term(+X, +Counts0, -Counts): apply succ/2 to X's count in the rbtree,
% or (presumably, per rbutils:rb_app_or_new) insert X with count 1 if absent.
incr_term(X, Counts0, Counts) :-
    rb_app_or_new(X, succ, =(1), Counts0, Counts).
% add_mul(+X-N, +Sum0, -Sum): fold step for expand_wsums; posts the
% constraints for Sum = Sum0 + float(N) * X.
add_mul(X-N, Sum0, Sum) :-
    Weight is float(N),
    mul(Weight, X, Term),
    add(Term, Sum0, Sum).
% expand_wsums: turn every pending wsum/2 into explicit mul/add constraints
% (starting from 0.0), then discard the trigger itself.
expand_wsums \ wsum(Counts, Total) <=> rb_fold(add_mul, Counts, 0.0, Total).
expand_wsums <=> true.
27
% Operation constraints with algebraic simplifications.
% CHR head matching is by identity, so the 0.0 / 1.0 patterns below only fire
% on float constants.  Both pow/3 rules use arithmetic guards so that integer
% and float exponents are simplified alike, and the x^0 rule yields the float
% 1.0 so that the float-identity rules (e.g. mul(1.0,X,Y)) can fire on it.
mul(0.0,_,Y) <=> Y=0.0.                 % 0*x = 0
mul(_,0.0,Y) <=> Y=0.0.                 % x*0 = 0
mul(1.0,X,Y) <=> Y=X.                   % 1*x = x
mul(X,1.0,Y) <=> Y=X.                   % x*1 = x
mul(X,Y,Z1) \ mul(X,Y,Z2) <=> Z1=Z2.    % share identical products
pow(K,X,Y) <=> K =:= 1 | Y=X.           % x^1 = x (guard matches floats and ints)
pow(K,_,Y) <=> K =:= 0 | Y=1.0.         % x^0 = 1 (guard matches floats and ints)
add(0.0,X,Y) <=> Y=X.                   % 0+x = x
add(X,0.0,Y) <=> Y=X.                   % x+0 = x
add(X,Y,Z1) \ add(X,Y,Z2) <=> Z1=Z2.    % share identical sums
39
% Log-sum-exp and softmax.
%   lse(Xs, Y):          Y = log(sum(map(exp, Xs)))
%   stoch_exp(Xs, Ys):   Ys = stoch(map(exp, Xs)), i.e. exp(Xs) normalised to sum 1
%   stoch_exp(Xs, N, Y): Y is the N-th element of Ys, N ranging over the
%                        indices produced by measure/2
%   mes(Xs, M, Ws, S):   shared max / exp(X - Max) / sum computation used by
%                        both lse/2 and stoch_exp/2
lse([X], Y) <=> Y = X.                                  % single element: identity
lse(Xs, Y) \ lse(Xs, Y2) <=> Y = Y2.                    % share identical lse terms
lse(Xs, _) ==> mes(Xs, _, _, _).
stoch_exp(Xs, Ys) \ stoch_exp(Xs, Ys2) <=> Ys = Ys2.    % share identical softmaxes
stoch_exp(Xs, Ys) ==>
    mes(Xs, _, _, _),
    measure(Xs, Idxs),
    maplist(stoch_exp(Xs), Idxs, Ys).
mes(Xs, M, Ws, S) \ mes(Xs, M2, Ws2, S2) <=> M = M2, Ws = Ws2, S = S2.
48
% Propagate derivatives backwards through unary, binary and vector operators.
% deriv(L, X, DX): DX is the derivative of L with respect to X.  Every
% contribution T to DX is posted as agg(T, DX); go/0 later sums them into DX.
deriv(L,X,DX) \ deriv(L,X,DX1) <=> DX=DX1.          % one deriv per (L,X)
deriv(L,_,DX) <=> ground(L) | DX=0.0.               % constant root
deriv(L,L,DL) ==> DL=1.0.                           % dL/dL = 1
deriv(_,_,DX) ==> var(DX) | acc(DX).                % mark DX for accumulation
deriv(L,Y,DY), pow(K,X,Y) ==> deriv(L,X,DX), dpow(K,X,Z), mul(DY,Z,T), agg(T,DX).
deriv(L,Y,DY), exp(X,Y) ==> deriv(L,X,DX), mul(Y,DY,T), agg(T,DX).
deriv(L,Y,DY), llog(Y,X) ==> deriv(L,X,DX), mul(Y,DY,T), agg(T,DX).
deriv(L,Y,DY), log(X,Y) ==> deriv(L,X,DX), pow(-1,X,RX), mul(RX,DY,T), agg(T,DX).
deriv(L,Y,DY), add(X1,X2,Y) ==> maplist(agg_add(L,DY),[X1,X2]).
deriv(L,Y,DY), mul(X1,X2,Y) ==> maplist(agg_mul(L,DY),[X1,X2],[X2,X1]).
deriv(L,Y,DY), max(X1,X2,Y) ==> maplist(agg_max(L,DY),[X1,X2],[X2,X1]).
deriv(L,Y,DY), lse(Xs,Y) ==> stoch_exp(Xs,Ps), maplist(agg_mul(L,DY),Xs,Ps).
% Softmax element: Y = S_N where S = stoch_exp(Xs) = Ps, so
%   dS_N/dX_j = S_N * ([j == N] - S_j).
% Ps is taken from the stoch_exp/2 constraint that created this stoch_exp/3,
% which gives every S_j rather than only S_N.  agg_mul/agg_add skip entries
% of Xs that are not variables, so constants receive no derivative.
deriv(L,Y,DY), stoch_exp(Xs,N,Y), stoch_exp(Xs,Ps) ==>
    mul(DY,Y,T),                    % DY*S_N
    mul(-1.0,T,NT),                 % -DY*S_N
    maplist(agg_mul(L,NT),Xs,Ps),   % -DY*S_N*S_j into every dX_j
    nth1(N,Xs,XN),
    agg_add(L,T,XN).                % +DY*S_N into dX_N (diagonal term)
69
% dpow(+K, ?X, -T): posts T = K * X^(K-1), the derivative of X^K, using
% pow/3 and mul/3 constraints (the factor K is converted to a float).
dpow(K, X, T) :-
    KMinus1 is K - 1,
    KFloat is float(K),
    pow(KMinus1, X, XPow),
    mul(KFloat, XPow, T).
% Per-argument derivative contributions used by the deriv/3 rules.
% Each helper does nothing (and succeeds) when X is not a variable.
%
% agg_max(L, DY, X, Other): dL/dX += T where chi(X, Other, DY, T)
% (presumably DY masked by whether X is the maximum -- confirm chi/4 semantics).
agg_max(L, DY, X, Other) :-
    (   var(X)
    ->  deriv(L, X, DX),
        chi(X, Other, DY, T),
        agg(T, DX)
    ;   true
    ).
% agg_mul(L, DY, X, Factor): dL/dX += Factor * DY
agg_mul(L, DY, X, Factor) :-
    (   var(X)
    ->  deriv(L, X, DX),
        mul(Factor, DY, T),
        agg(T, DX)
    ;   true
    ).
% agg_add(L, DY, X): dL/dX += DY
agg_add(L, DY, X) :-
    (   var(X)
    ->  deriv(L, X, DX),
        agg(DY, DX)
    ;   true
    ).
% Keep at most one acc/1 marker per derivative variable.
acc(DX) \ acc(DX) <=> true.
75
% back(?Y): reverse-mode differentiation from the single output Y, then
% accumulation via go/0.  Does nothing when Y is not a variable.
back(Y) :- ( var(Y) -> diff(Y), go ; true ).

% diff(+Y): seed the derivative of Y with respect to itself as 1.0.
diff(Y) :- deriv(Y,Y,1.0).

% grad(+Ys): differentiate each output in Ys, then accumulate via go/0.
% Like back/1, outputs that are not variables are skipped: diff/1 on a
% constant fails, since deriv/3 forces the derivative of a ground root to 0.0
% and that clashes with the seed 1.0.
grad(Ys) :- include(var, Ys, Vars), maplist(diff, Vars), go.
80
% Summation of derivative contributions, triggered by go/0.
% acc(DX, Sum): Sum is the running total of the agg(_, DX) terms consumed so
% far; once no agg/2 for DX remains, DX is unified with that total.
acc(DX,Sum0), agg(Term,DX) <=> add(Term,Sum0,Sum), acc(DX,Sum).
acc(DX,Sum) <=> Sum=DX.

go \ deriv(_,_,_) <=> true.             % drop leftover deriv/3 constraints
go \ acc(DX) <=> acc(DX,0.0).           % start each sum at 0.0
go <=> true.
87
% upd_ops(+Emit, -List, +Tail): run the DCG body Emit to produce op terms at
% the front of List, then keep gathering the remainder with a new ops/2.
:- meta_predicate upd_ops(//,?,?).
upd_ops(Emit, List, Tail) :-
    call(Emit, List, Rest),
    ops(Rest, Tail).

% op(+Name, +Inputs, +Outputs)// emits the single term op(Name, Inputs, Outputs).
op(Op, Ins, Outs) --> [op(Op, Ins, Outs)].
91
% ops(G, T): move every remaining operation constraint out of the store and
% into the difference list G-T as op(Name, Inputs, Outputs) terms.
ops(G1,G2), add(X,Y,Z) <=> upd_ops(op(add, [X,Y], [Z]), G1, G2).
ops(G1,G2), mul(X,Y,Z) <=> upd_ops(op(mul, [X,Y], [Z]), G1, G2).
ops(G1,G2), max(X,Y,Z) <=> upd_ops(op(max, [X,Y], [Z]), G1, G2).
ops(G1,G2), pow(X,Y,Z) <=> upd_ops(op(pow, [X,Y], [Z]), G1, G2).
ops(G1,G2), log(X,Y) <=> upd_ops(op(log, [X], [Y]), G1, G2).
ops(G1,G2), exp(X,Y) <=> upd_ops(op(exp, [X], [Y]), G1, G2).
ops(G1,G2), esc(Op,X,Y) <=> upd_ops(op(Op, X, Y), G1, G2).
ops(_,_) \ llog(_,_) <=> true.          % llog/2 produces no op of its own

ops(_,_) \ stoch_exp(_,_,_) <=> true.   % outputs are emitted via stoch_exp/2
% lse/2 and stoch_exp/2 reuse the values computed by mes/4, so they are
% emitted while mes/4 is a kept head; mes/4 itself is emitted by a later rule.
% (mes/4 is kept, so it is not re-posted in the body.)
mes(Xs,M,_,S) \ ops(G1,G2), lse(Xs,Y) <=> upd_ops(add_log(S,M,Y), G1, G2).
mes(Xs,_,Ws,S) \ ops(G1,G2), stoch_exp(Xs,Ys) <=> upd_ops(divby_list(S,Ws,Ys), G1, G2).
ops(G1,G2), mes(Xs,M,Ws,S) <=> upd_ops(max_exp_sum(Xs,M,Ws,S),G1,G2).
ops(G1,G2), chi(X,Y,Z,I) <=> upd_ops(op(chi, [X,Y,Z], [I]), G1, G2).
ops(G1,G2) <=> G1=G2.                   % nothing left: close the list
107
% add_log(+Sum, +Max, +Y)// emits op(add_log, [Max, Sum], [Y]).
add_log(Sum, Max, Y) --> op(add_log, [Max, Sum], [Y]).

% divby_list(+Sum, +Ws, +Ys)// emits one div op per element, each taking
% inputs [W, Sum] to output Y (presumably Y = W / Sum).
divby_list(Sum, Ws, Ys) --> foldl(divby(Sum), Ws, Ys).
divby(Sum, W, Y) --> op(div, [W, Sum], [Y]).
111
% max_exp_sum(+Xs, +Max, +Ws, +Total)// emits, in order: the max of Xs, one
% exp_sub op per element (inputs [Max, X]), and the sum of the results.
max_exp_sum(Xs, Max, Ws, Total) -->
    op(max_list, Xs, [Max]),
    foldl(exp_sub(Max), Xs, Ws),
    op(sum_list, Ws, [Total]).
exp_sub(Max, X, W) --> op(exp_sub, [Max, X], [W]).
117
% gather_ops(+Ins, +Outs, -Sorted): collect every remaining operation via
% ops/2, index each op in an rbtree by its output variables, and traverse
% from Outs (with Ins pre-inserted) to build Sorted -- presumably ordered so
% each op follows its inputs; TODO confirm against eval//2.
gather_ops(Ins, Outs, Sorted) :-
    ops(AllOps, []),
    rb_empty(Empty),
    foldl(back_links, AllOps, Empty, Producers),
    traverse(Producers, Ins, Outs, Sorted-Empty, []-_).
122
% back_links(+Op)// : record Op in the rbtree under each of its outputs.
back_links(Op) -->
    { Op = op(_, _, Outputs) },
    foldl(back_link(Op), Outputs).
back_link(Op, Output) --> rb_add(Output, Op).

% traverse(+Producers, +Ins, +Outs)// : insert Ins into the rbtree held in the
% second component of the state pair, then run eval//2 on each of Outs.
traverse(Producers, Ins, Outs) -->
    \> foldl(insert, Ins),
    foldl(eval(Producers), Outs).
insert(Var) --> rb_add(Var, t).
127
128eval(BS, Var) -->
129 ( ({nonvar(Var)}; \>
Reverse-mode automatic differentiation using CHR.
Todo:
*/