Sunday, February 3, 2008

Higher order multivariate automatic differentiation in Haskell.

I've started a game in Haskell, and I've run into a couple of situations where I need derivatives of multivariate equations. It's embarrassing how long I stared at this, Overloading Haskell numbers, part 2, Forward Automatic Differentiation, and this, Non-standard analysis, automatic differentiation, Haskell, and other stories, without understanding what was going on. Finally I broke down and read this, Lazy Multivariate Higher-Order Forward-Mode AD, and this, Lazy Differential Algebra and its Applications, and the different approaches started to sink in. Like a first order level of understanding. I had a discussion with a physicist at work who kept looking at me like I had a second head when I was trying to explain the system. We talked through it, and finally things started to crystallize.

Looking back, as far as code changes go, this is trivial. Just substitute a set for a list, and provide indexed variables. Conceptually, however, differentiation became very reminiscent of compilation. All my false starts at building up the sets of values reminded me of writing code for instruction selection. Evaluating f(x) was similar to traversing an AST, whereas actually building the sets for the derivatives was like code generation based on the AST. Blah blah blah, Math is hard, let's go shopping.


> module Diff where
> import qualified Data.IntMap as IM
> data Diff a = C a
>             | D a (IM.IntMap (Diff a))
>             deriving (Eq,Ord,Show)

I keep track of the different known variables in an IntMap. Following after Karczmarczuk, I use crappy names. C is for Constant. D is for Differentiable expression. Every Diff has a value; Differentiable expressions also have a map from variable ids to differentiable expressions.


> dVar x id = D x $ IM.fromList [(id,C 1)]

For example, with a simple variable like dVar 3 0, this builds the structure for variable id 0 with value 3. The first derivative with respect to id 0 is 1; with respect to anything else, 0.
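
As a rough illustration of that structure, here is a minimal standalone sketch. It repeats the type and dVar from above (with the second argument renamed to i) and adds a hypothetical helper, firstDerivs, which is not part of the original module and only lists the ids a value carries a first derivative for.

```haskell
import qualified Data.IntMap as IM

-- Same shape as the Diff type in this post, repeated so the sketch stands alone.
data Diff a = C a
            | D a (IM.IntMap (Diff a))
            deriving (Eq, Ord, Show)

-- Same as dVar above; the id argument is named i to avoid shadowing Prelude's id.
dVar :: Num a => a -> Int -> Diff a
dVar x i = D x (IM.fromList [(i, C 1)])

-- Hypothetical helper: the variable ids with a recorded first derivative.
firstDerivs :: Diff a -> [Int]
firstDerivs (C _)      = []
firstDerivs (D _ vars) = IM.keys vars

-- The variable with id 0, evaluated at 3.
example :: Diff Int
example = dVar 3 0
-- example is D 3 (fromList [(0,C 1)]), and firstDerivs example is [0]
```

Loading this in GHCi and typing example should show the value 3 paired with a one-entry map for id 0.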


> instance Functor Diff where
>     fmap f (C a) = C (f a)
>     fmap f (D a set) = D (f a) (IM.map (fmap f) set)

Although I don't use fmap as often as I should, Karczmarczuk's code used it. It's such a good idea. I have a nagging suspicion that I should implement map and fold for random datatypes I implement, but I never get around to doing it. If nothing else, it gives a nice understanding of how to really use the thing.


> add (C x) (C y) = C (x + y)
> add (C x) (D y vars) = D (x+y) vars
> add (D x vars) (C y) = D (x+y) vars
> add (D x vx) (D y vy) = D (x+y) (IM.unionWith (+) vx vy)

Add is great. Anything involving a constant is trivial. Any time there's a variable collision, add the children. So clean, so simple. So, (dVar 1 0) + (dVar 1 0) yields D 2 (fromList [(0,C 2)]). f(x) = x + x = 2x evaluated at 1 is 2, and taking the first derivative makes the single variable fall away, giving just 2.
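
The collision and no-collision cases can be tried in isolation with a small standalone sketch. One assumption to flag: colliding derivatives are combined here with add itself rather than (+), which is the same thing once the Num instance exists but lets the sketch compile without it. The names sameVar and twoVars are made up for the example.

```haskell
import qualified Data.IntMap as IM

data Diff a = C a
            | D a (IM.IntMap (Diff a))
            deriving (Eq, Show)

dVar :: Num a => a -> Int -> Diff a
dVar x i = D x (IM.fromList [(i, C 1)])

-- add as in the post, except that colliding derivatives are combined with
-- add itself instead of (+), so no Num instance for Diff is needed here.
add :: Num a => Diff a -> Diff a -> Diff a
add (C x)    (C y)    = C (x + y)
add (C x)    (D y vs) = D (x + y) vs
add (D x vs) (C y)    = D (x + y) vs
add (D x vx) (D y vy) = D (x + y) (IM.unionWith add vx vy)

-- x + x at x = 1: both operands carry id 0, so the derivatives collide and sum.
sameVar :: Diff Int
sameVar = add (dVar 1 0) (dVar 1 0)
-- D 2 (fromList [(0,C 2)])

-- x + y at x = 1, y = 5: ids 0 and 1 are distinct, so both derivatives are kept.
twoVars :: Diff Int
twoVars = add (dVar 1 0) (dVar 5 1)
-- D 6 (fromList [(0,C 1),(1,C 1)])
```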


> mul (C x) (C y) = C (x * y)
> mul (C x) v = fmap (x*) v
> mul v (C y) = fmap (*y) v
> mul p@(D x vx) q@(D y vy) = D (x*y) (IM.unions
>     [(IM.intersectionWith (\a b-> a * q + p * b) vx vy),
>      (IM.map (* q) (IM.difference vx vy)),
>      (IM.map (p *) (IM.difference vy vx))
>     ])

Mul took me a long time to get right. There are three cases to get just right. First, all of the overlapping derivatives must use the product rule. This is straight out of the list implementation. It wasn't at all obvious, to me, that the other cases are just multiplication. An afternoon with pencil and paper convinced me.
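
One way to sketch the pencil-and-paper argument, writing p and q for the two factors and x_i for the variable with id i (notation introduced here, not used in the code): the product rule applies to every variable, and a variable missing from one factor's map simply has a zero derivative there.

```latex
\frac{\partial (p\,q)}{\partial x_i}
  = \frac{\partial p}{\partial x_i}\,q + p\,\frac{\partial q}{\partial x_i},
\qquad
\frac{\partial q}{\partial x_i} = 0
  \;\Longrightarrow\;
\frac{\partial (p\,q)}{\partial x_i} = \frac{\partial p}{\partial x_i}\,q
```

That is the IM.map (* q) case for ids found only in vx. By symmetry, ids found only in vy give p times the derivative of q, which is the IM.map (p *) case.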


> df [] (C x) = Just x
> df [] (D x vars) = Just x
> df (_:_) (C _) = Nothing -- a constant has no further derivatives
> df (i:is) (D x vars) = df is =<< (IM.lookup i vars)

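df allows indexing into the resulting Diff structure: the list gives the variable ids to differentiate by, in order, and the empty list returns the plain value. A result of Nothing means the derivative is zero.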
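
To make the Nothing-means-zero convention concrete, here is a small standalone sketch. It repeats df with the case for constants spelled out, and adds a hypothetical wrapper, dfOrZero, that is not in the original code and turns Nothing into 0. The value square is built by hand for the example rather than computed with mul.

```haskell
import qualified Data.IntMap as IM
import Data.Maybe (fromMaybe)

data Diff a = C a
            | D a (IM.IntMap (Diff a))
            deriving (Eq, Show)

-- df as in the post, with the case of asking a constant for a derivative
-- written out explicitly so the function is total.
df :: [Int] -> Diff a -> Maybe a
df []     (C x)      = Just x
df []     (D x _)    = Just x
df (_:_)  (C _)      = Nothing
df (i:is) (D _ vars) = df is =<< IM.lookup i vars

-- Hypothetical convenience wrapper: read Nothing as a zero derivative.
dfOrZero :: Num a => [Int] -> Diff a -> a
dfOrZero path = fromMaybe 0 . df path

-- Hand-built value for f(x) = x^2 at x = 3, with x as id 0:
-- f = 9, first derivative 6, second derivative 2, nothing beyond that.
square :: Diff Int
square = D 9 (IM.fromList [(0, D 6 (IM.fromList [(0, C 2)]))])
```

With these definitions, df [0,0] square is Just 2, while df [0,0,0] square and df [1] square are both Nothing, which dfOrZero reports as 0.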


> test x y = x^3 * y + x^2 * y^2
> testdx x y = 3*x^2 * y + 2*x * y^2
> testdy x y = x^3 + x^2 * 2 * y
> testdx2 x y = 6*x * y + 2*y^2
> testdx2dy x y = 6*x + 4*y

Finally, some evaluations at x = 2, y = 3.

*Diff> test 2 3
60
*Diff> df [] $ test (dVar 2 0) (dVar 3 1)
Just 60
*Diff> testdx 2 3
72
*Diff> df [0] $ test (dVar 2 0) (dVar 3 1)
Just 72
*Diff> testdy 2 3
32
*Diff> df [1] $ test (dVar 2 0) (dVar 3 1)
Just 32
-- here you can see the order of taking the derivative doesn't matter.
*Diff> testdx2dy 2 3
24
*Diff> df [0,0,1] $ test (dVar 2 0) (dVar 3 1)
Just 24
*Diff> df [1,0,0] $ test (dVar 2 0) (dVar 3 1)
Just 24
*Diff> df [0,1,0] $ test (dVar 2 0) (dVar 3 1)
Just 24


Everything else is cribbed from people much smarter than me. As I said before, substituting a map of variables to their respective derivatives is the only change through the rest of the code. See the links at the beginning of this article for beautiful derivations of the following.


> instance (Num a) => Num (Diff a) where
>     (+) = add
>     (*) = mul
>     negate = fmap negate
>     signum (C x) = C $ signum x
>     signum (D x _) = C $ signum x
>     abs (C x) = C $ abs x
>     abs p@(D x _) = f $ fmap (*(signum x)) p
>         where f (D _ flip) = D (abs x) flip
>     fromInteger x = C (fromInteger x)

> instance (Fractional a) => Fractional (Diff a) where
>     recip (C x) = C (recip x)
>     recip p@(D x vars) = D (recip x) (IM.map (* (negate (recip (p * p)))) vars)
>     fromRational r = C (fromRational r)

> lift (f:f') p@(D x vars) = D (f x) (IM.map (* lift f' p) vars)
> lift (f:f') p@(C x) = C (f x)
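
lift takes a list holding a function followed by its successive derivatives. A sketch of the rule it implements, in notation assumed here (p is the argument, x_i the variable with id i):

```latex
\frac{\partial}{\partial x_i} f(p) = f'(p)\,\frac{\partial p}{\partial x_i}
```

The factor f'(p) is itself built as lift f' p on the tail of the list, so a cyclic list such as sin, cos, -sin, -cos supplies every higher derivative lazily.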

> instance (Floating a) => Floating (Diff a) where
>     pi = C pi
>     exp (C x) = C (exp x)
>     exp (D x vars) = r where r = D (exp x) (IM.map (mul r) vars)
>     log (C x) = C (log x)
>     log p@(D x vars) = D (log x) (IM.map (*(recip p)) vars)
>     sqrt (C x) = C (sqrt x)
>     sqrt (D x vars) = r where r = D (sqrt x) (IM.map (* (recip (2*r))) vars)
>     sin = lift (cycle [sin, cos, negate.sin, negate.cos])
>     cos = lift (cycle [cos, negate.sin, negate.cos, sin])
>     acos (C x) = C (acos x)
>     acos p@(D x vars) = D (acos x) (IM.map (* (negate $ recip $ sqrt $ 1-p*p)) vars)
>     asin (C x) = C (asin x)
>     asin p@(D x vars) = D (asin x) (IM.map (* (recip $ sqrt $ 1-p*p)) vars)
>     atan (C x) = C (atan x)
>     atan p@(D x vars) = D (atan x) (IM.map (* (recip $ p*p+1)) vars)
>     sinh x = (exp x - exp (-x))/2
>     cosh x = (exp x + exp (-x))/2
>     asinh x = log (x + sqrt (x*x + 1))
>     acosh x = log (x + sqrt (x*x - 1))
>     atanh x = (log (1+x) - log(1-x))/2



Oh, by the way, this post, like the others, is literate Haskell: just paste it into a text file with a .lhs suffix and you can run it all.

Jason

-- edit -- added negate to the Num instance, as turingtest suggested. Thanks.
-- edit -- flipped sign in atan as luis suggested.

6 comments:

Joe Fredette said...

Awesome, this is wicked cool stuff.

TuringTest said...

You need to define the negate part of Num:

negate = fmap negate

Luis Armendariz said...

For taking the derivative of atan, don't you mean "recip $ p*p+1"?

meteficha said...

How about using it with http://hackage.haskell.org/packages/archive/numbers/2008.4.20.1/doc/html/Data-Number-Symbolic.html ?

How about including it in http://hackage.haskell.org/cgi-bin/hackage-scripts/package/numbers ?

Nice!

Unknown said...

recip might contain an error: is it not perhaps


recip (D x vars) = D (recip x) (IM.map (* (C (- recip x * recip x))) vars)

Unknown said...

sorry - I mean that recip should be:


recip p@(D x vars) = D (recip x) (IM.map (* (- recip (p * p))) vars)