Saturday, April 28, 2012

First-Class Everything: Loops, Tail Calls, and Continuations

      From where we left off in "Schrodinger's Equation of Software", we had first-class functions, first-class syntax, and first-class environments. Since we can build anything out of lambdas, and we can build lambdas out of vau and wrap, our language ought to be complete, right? Well, yes, if we actually had a perfect implementation of the mathematical ideal of vau calculus.
      Unfortunately, Python's limited stack depth means that recursive loops in the interpreted language will overflow the Python stack once they run long enough. There are two ways to solve this: we can add another built-in function just to handle loops, or we can separate the interpreted stack from the Python stack and make tail calls work.
      The simplest way to make tail calls work is to use trampolining. So, we'll define a special interpreter data structure for a tail call that encapsulates everything we'll need to execute the function call (i.e., all of the arguments to eval), and return that instead of the value of the function call whenever we have a function call in tail position.

class Tail():
    def __init__(self,expr,env):
        self.expr = expr
        self.env = env

    def __iter__(self):
        yield self.expr
        yield self.env

      Then, we modify eval to check if it got a value or a tail call instruction every time it executes a procedure. If it did get a tail call, it replaces its arguments with the data in the tail call object and loops:

def eval(x, env):
    "Evaluate an expression in an environment."
    while True:
        if isa(x, Symbol):          # variable reference
            return env[x]
        elif isa(x, list):          # (proc exp*)
            proc = eval(x[0], env)
            if hasattr(proc, '__call__'):
                val = proc(env,*x[1:])
                if isa(val, Tail):
                    x, env = val
                else:
                    return val
            else:
                raise ValueError("%s = %s is not a procedure" %
                                  (to_string(x[0]),to_string(proc)))
        else: return x
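
      To see the trampolining idea in isolation, away from the interpreter, here's a minimal sketch. The names Bounce, trampoline, and sum_to are made up for this illustration and aren't part of the interpreter; the point is just that the tail call is returned as data and executed by a loop, so Python's stack stays flat:

```python
class Bounce:
    """A deferred tail call: a function plus the arguments to call it with."""
    def __init__(self, func, *args):
        self.func = func
        self.args = args

def trampoline(result):
    """Keep executing deferred calls until a real value comes back."""
    while isinstance(result, Bounce):
        result = result.func(*result.args)
    return result

def sum_to(n, acc=0):
    """Tail-recursive sum of 1..n; returns a Bounce instead of recursing."""
    if n == 0:
        return acc
    return Bounce(sum_to, n - 1, acc + n)

# Far deeper than Python's default recursion limit (1000),
# but each step returns to the loop before the next one starts.
depth = 100000
total = trampoline(sum_to(depth))
print(total)  # 5000050000
```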

      Note that eval should *never* return a tail call object; it's not a value in the language, it's just bookkeeping for the interpreter. That works pretty well; if you write proper tail-recursive loops, they'll run just fine. But we can do better. For one thing, there's still a recursive call to eval in eval. For another, not all calls to eval in the builtin functions are tail calls. E.g., here's the definition of "if":

    'if': lambda v,z,t,f: Tail((t if eval(z,v) else f), v)

      The middle call to eval can't be replaced by Tail. The basic binary operators can't be fixed at all, and these sequences of eval -> procedure -> eval can still result in blowing up the stack. Ideally, we want to eliminate any recursive calls to eval, so no matter what happens to the interpreted language stack, Python's stack doesn't grow. We could make sure that *all* function calls are in tail position by transforming our interpreted programs into CPS, but that would add a whole lot of complexity to our interpreter to do the transformation. We can get the same effect by noting that the continuation of an interpreted expression is exactly the same as the continuation of the interpreter when it's interpreting that expression; therefore, we just have to CPS the interpreter. And since our eval function is tiny, that's no trouble at all!
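      If CPS itself is unfamiliar, here's a toy illustration in plain Python. The function names are invented for this sketch and have nothing to do with the interpreter; it only shows how a pending operation moves out of the Python stack and into an explicit callback:

```python
def fact(n):
    """Direct style: the multiplication waits for the recursive call to return."""
    if n == 0:
        return 1
    return n * fact(n - 1)

def fact_cps(n, k=lambda x: x):
    """CPS: 'what to do with the result' is passed along explicitly as k."""
    if n == 0:
        return k(1)
    # The pending multiplication moves into the continuation,
    # which leaves the recursive call in tail position.
    return fact_cps(n - 1, lambda result: k(n * result))

print(fact(5), fact_cps(5))  # 120 120
```

      On its own this doesn't save any Python stack, since Python doesn't eliminate tail calls; it's the combination of CPS with the trampoline loop that pays off.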
      The CPSed eval function looks like this:

def eval(x, env, k=lambda x:x):
    "Evaluate an expression in an environment."
    val = None
    while True:
        if isa(x, Symbol):          # variable reference
            val = k(env[x])
        elif isa(x, list):          # (proc exp*)
            def capture_args():
                # snapshot the current call; x, env, and k get rebound by the loop
                cp, cx, cv, ck = x[0], x[1:], env, k
                def try_call(proc):
                    if hasattr(proc, '__call__'):
                        return proc(ck,cv,*cx)
                    raise ValueError("%s = %s is not a procedure" %
                                      (to_string(cp),to_string(proc)))
                return try_call
            x, k = x[0], capture_args()
            continue
        else:
            val = k(x)
        if not isa(val, Tail):
            return val
        (x,env,k) = val

      Of course, we also have to re-write all of our builtin functions, which are now called with an extra continuation parameter, in CPS, but we'll get to that later. Now, the k parameter to eval isn't a real continuation, because Python doesn't have them; it's a callback that actually does return a value, but the value it returns is the value of executing the continuation. At the top level, when calling eval from outside, there's no continuation to process values, so the callback is just the identity function. Once again, eval can never return a Tail, nor can it return a continuation, but continuations can return Tails (in fact, they usually will), so we have to make sure that we check that whenever we get a value via a call to k.
      Function application is a little bit special; we build our own continuation in the eval function that will actually do the procedure call, using a Python closure to save the necessary state, then loop to evaluate the procedure itself. We also avoid the overhead of constructing and destructing a Tail object since we already have all the bits that we need right there, and the environment stays the same. Of course, we also have to update the definition of a Tail, since we've increased the number of parameters to eval:

class Tail():
    def __init__(self,expr,env,k):
        self.expr = expr
        self.env = env
        self.k = k

    def __iter__(self):
        yield self.expr
        yield self.env
        yield self.k

      To help support recursion, we'll update the way that closure __call__s are handled a little bit:

def __call__(self, k, call_env, *args):
    new_env = Env(zip(self.vars, args), self.clos_env)
    new_env[self.sym] = call_env
    if 'self' not in self.vars: new_env['self'] = self #safe recursion, unless a parameter is named self
    return Tail(self.body, new_env, k)

      Now, every procedure gets access to a special "extra parameter" called "self" that refers to the currently executing function, unless it's overwritten by an explicit parameter of the same name. This ensures that recursion will still work even if the function is re-assigned to a different name, or if it's anonymous, without having to bother with the Y-combinator. Now we can write a simple infinite loop:

((vau (a) %
    (begin
        (define va (eval % a))
        (print va)
        (self (+ va 1))))
    0)

      And it works! A never-ending stream of increasing numbers printed to the console, and Python doesn't complain about blowing up the stack.
      Back to CPSing the built-in functions. Only a few of them are actually interesting: define, begin, list, and wrap. Define is interesting because it has side-effects: it modifies the environment. Other side-effectful functions (like set! and print) are essentially identical in structure:

def defvar(k,v,var,e):
    def def_k(val):
        v[var] = val
        return k(val)
    return Tail(e,v,def_k)

      It just builds a continuation that performs its side-effects before passing the return value on to the external continuation.
      Begin has to build a chain of continuations that will evaluate successive expressions in an arbitrarily long list until it gets to the end and passes the value of that expression to the external continuation. It looks like this:

def sequence(k,v,*x):
    if len(x) == 0: return k(None)
    return Tail(x[0],v,k if len(x) == 1 else lambda vx: sequence(k,v,*x[1:]))

      The apparent recursive call to sequence does not actually risk exploding the stack, because it's not called from here; it's only called as the value of k somewhere in eval. Thus, it only grows the stack by two frames: one for the lambda that throws away intermediate values, and one for the next call to sequence, which will return without actually calling itself. The implementation of "cond" is similar.
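      As a rough idea of what "similar" could mean for cond, here's a hedged sketch rather than the code from the repository. It assumes each clause is a (test, body) pair, and it uses a toy driver called run (a stand-in for the eval loop, invented here) in which strings are variable names and anything else is a literal:

```python
class Tail:
    """Same shape as the three-field Tail: everything eval needs to continue."""
    def __init__(self, expr, env, k):
        self.expr, self.env, self.k = expr, env, k

def run(val):
    """Toy stand-in for eval's loop (an assumption of this sketch):
    strings are looked up in env, anything else evaluates to itself."""
    while isinstance(val, Tail):
        x, env, k = val.expr, val.env, val.k
        val = k(env[x] if isinstance(x, str) else x)
    return val

def cond(k, v, *clauses):
    """Hypothetical CPS cond: each clause is a (test, body) pair."""
    if len(clauses) == 0:
        return k(None)                      # no clause matched
    test, body = clauses[0]
    def test_k(result):
        if result:
            return Tail(body, v, k)         # the chosen body is a tail call
        return cond(k, v, *clauses[1:])     # otherwise try the next clause
    return Tail(test, v, test_k)

env = {'a': False, 'b': True}
print(run(cond(lambda val: val, env, ('a', 1), ('b', 2), (True, 3))))  # 2
```

      As with sequence, the apparent recursion in test_k only ever runs from the driver loop, so it returns a fresh Tail instead of piling up frames.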
      List and wrap initially seem similar to begin. However, they have to actually save intermediate results, and that makes a big difference. List is implemented with this function to map evaluation loops over a list of argument expressions:

def cps_map_eval(k,v,*x):
    vx = []
    arglen = len(x)
    def map_loop(i):
        if i == arglen: return k(vx)
        else:
            def assign_val(vmx):
                vx.append(vmx)
                return map_loop(i+1)
            return Tail(x[i],v,assign_val)
    return map_loop(0)

      It builds a chain of continuations that tack the value of an expression onto the end of the mapped argument list and then evaluates the next one. Wrap uses that same function to evaluate arguments to a wrapped procedure:

def wrap(k,v,p):
    return Tail(p,v,
        lambda vp:
            k(lambda ck,cv,*x:
                cps_map_eval(lambda vx: vp(ck,cv,*vx),cv,*x)))

      Since, after CPSing, every call to eval in the builtin functions is transformed into constructing a Tail object, CPSing our interpreter has also nicely eliminated the dependency between the definition of the global environment and the interpreter itself.
      Now that we have proper tail calls and a continuation chain that's independent of the Python stack, we're just about ready to implement first-class continuations! There is, however, one little problem: if we give interpreted programs access to their continuations, anything that tries to activate a continuation into a function argument list will break cps_map_eval. The continuation will tack new values onto the old argument list, re-evaluating any arguments after the one whose continuation was called, and try to continue to the next procedure with too many arguments!
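      The failure is easy to reproduce outside the interpreter. In this sketch, cps_map_eval is the first version copied from above, while run and seen are invented for the demonstration: run is a toy driver that treats every "expression" as an already-evaluated value and records each continuation it invokes, so that one can be re-entered later:

```python
class Tail:
    def __init__(self, expr, env, k):
        self.expr, self.env, self.k = expr, env, k

def run(val, seen=None):
    """Toy driver (an assumption of this sketch): evaluating an expression
    just means handing it to its continuation. Continuations are recorded
    in `seen` when a list is supplied."""
    while isinstance(val, Tail):
        if seen is not None:
            seen.append(val.k)
        val = val.k(val.expr)
    return val

def cps_map_eval(k, v, *x):              # the first version, as above
    vx = []
    arglen = len(x)
    def map_loop(i):
        if i == arglen: return k(vx)
        else:
            def assign_val(vmx):
                vx.append(vmx)
                return map_loop(i+1)
            return Tail(x[i], v, assign_val)
    return map_loop(0)

seen = []
snapshot = lambda vx: list(vx)           # copy the argument list as delivered
first = run(cps_map_eval(snapshot, None, 2, 5), seen)
again = run(seen[0](3))                  # re-enter the continuation of slot 0
print(first)   # [2, 5]
print(again)   # [2, 5, 3, 5]
```

      The second run appends to the list left over from the first, so the procedure would be handed four arguments instead of two.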
      To fix this, we need to ensure two things: 1) argument evaluation order is arbitrary and unimportant; 2) calls to continuations that return to a function argument slot fill in the value of that slot, and only that slot. The new version of cps_map_eval that accomplishes those goals looks like this:

def cps_map_eval(k,v,*x):
    """
    Evaluates the elements of an argument list,
    creating continuations that will assign values
    to the correct indices in the evaluated list.
    """
    arglen = len(x)
    if arglen == 0: return k([])
    argv = [None]*arglen
    done = [False]
    def map_loop(i):
        if i == arglen:
            done[0] = True
            return k(argv)
        else:
            def assign_val(vmx):
                if not done[0]:     #on the first time through,
                    argv[i] = vmx   #evaluate the next argument in the list
                    return map_loop(i+1)
                else: #if this is a continuation call,
                    new_argv = argv[:]   #copy the other argument values
                    new_argv[i] = vmx    #and just overwrite this one
                    return k(new_argv)
            return Tail(x[i],v,assign_val)
    return map_loop(0)

      Note that "done" is a single-element list rather than a simple boolean variable due to Python's scoping rules: closures can't rebind variables in enclosing scopes, but they can mutate the contents of containers. With just a little more work, this could be further modified to ensure that argument evaluation occurs in parallel.
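      For comparison, Python 3's nonlocal statement does let a closure rebind a variable in an enclosing scope, so there "done" could be a plain boolean. Here's a sketch of that variant, assuming Python 3. As a check that re-entering a slot's continuation replaces only that slot, it is driven by a toy loop called run (invented for this sketch) that treats every expression as an already-evaluated value and records the continuations it invokes:

```python
class Tail:
    def __init__(self, expr, env, k):
        self.expr, self.env, self.k = expr, env, k

def run(val, seen=None):
    """Toy driver (an assumption of this sketch): hand each 'expression'
    straight to its continuation, optionally recording the continuations."""
    while isinstance(val, Tail):
        if seen is not None:
            seen.append(val.k)
        val = val.k(val.expr)
    return val

def cps_map_eval(k, v, *x):
    arglen = len(x)
    if arglen == 0: return k([])
    argv = [None]*arglen
    done = False
    def map_loop(i):
        nonlocal done                    # Python 3 only: rebind the outer flag
        if i == arglen:
            done = True
            return k(argv)
        def assign_val(vmx):
            if not done:                 # first pass: fill slot i, move on
                argv[i] = vmx
                return map_loop(i+1)
            new_argv = argv[:]           # re-entry: copy, then overwrite slot i
            new_argv[i] = vmx
            return k(new_argv)
        return Tail(x[i], v, assign_val)
    return map_loop(0)

mul = lambda argv: argv[0] * argv[1]
seen = []
print(run(cps_map_eval(mul, None, 2, 5), seen))  # 10
print(run(seen[0](3)))                           # 15
```

      Only map_loop needs the nonlocal declaration, because assign_val reads the flag but never assigns to it.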
      Now that our continuations are safe, we need to decide how to give the interpreted language access to them. It seems a little disingenuous to make calls to continuations look just like calls to functions, but since we turned all of our syntax into function calls, we don't really have any special syntax that we could use that would look different. So, we'll just go with that. We'll wrap up interpreter callbacks in a special callable Continuation object for the interpreted language:

class Continuation():
    def __init__(self,k):
        self.k = k

    def __call__(self, call_k, call_env, *args):
        if len(args) != 1:
            raise Exception("Continuations take exactly 1 argument.")
        return self.k(args[0])

      We could just use a Python closure for that, but making it a class of its own results in nicer debugging output. When called, it throws out the current environment and continuation, and activates the saved callback instead. This version of Continuation acts like a vau expression: it doesn't evaluate its argument! We could use wrap on a continuation value, but it's probably better if explicit arguments to continuations behave the same way as implicit return values, so here's an alternate version that evaluates the continuation's argument:

class Continuation():
    def __init__(self,k):
        self.k = k

    def __call__(self, call_k, call_env, *args):
        if len(args) != 1:
            raise Exception("Continuations take exactly 1 argument.")
        return Tail(args[0],call_env,self.k)

      This even looks more like what it actually does: returns to another iteration of the eval loop, but replacing the current continuation with the saved continuation.
      We still need to provide a way for programs to get references to continuations. Rather than adding an extra built-in form like call/cc to bind continuations to the environment, we'll just make another minor edit to how closures are __call__ed:

def __call__(self, k, call_env, *args):
    new_env = Env(zip(self.vars, args), self.clos_env)
    new_env[self.sym] = call_env
    if 'self' not in self.vars: new_env['self'] = self #safe recursion, unless a parameter is named self
    if 'return' not in self.vars: new_env['return'] = Continuation(k) #likewise for return
    return Tail(self.body, new_env, k)

      Every function gets a reference to the current continuation at its call site as an extra over-writable argument, just like "self", which it can return or pass on to other function calls.
      Here are a couple of simple programs to verify that it really works:

(print ((vau () %
    (begin
        (+ 1 2)
        (return 7)
        19))))

      Prints 7, not 19, because the current continuation that would return 19 is thrown out and replaced.

(define c nil)
(define mul (lambda (a b) (* a b)))
(print
    (mul 
        ((vau () % (begin (set! c return) 2)))
        (+ 2 3)))
(c 3)
(c (+ 1 2))

      Prints 10, 15, 15, restarting the calculation with a different value for the first argument each time c is called.

As before, working code can be found at https://github.com/gliese1337/schrodinger-lisp/.
