From d9cd1ef7e06c21fe1efe40a57299f816a72d6647 Mon Sep 17 00:00:00 2001 From: David Thompson Date: Fri, 19 Aug 2022 14:20:04 -0400 Subject: Step 10: Proper Tail Calls --- compiler.scm | 104 +++++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 98 insertions(+), 6 deletions(-) (limited to 'compiler.scm') diff --git a/compiler.scm b/compiler.scm index 01bc40e..14160bb 100644 --- a/compiler.scm +++ b/compiler.scm @@ -77,6 +77,9 @@ (define (funcall? x) (op-eq? x 'funcall)) +(define (tail-call? x) + (op-eq? x 'tail-call)) + (define (lambda? x) (op-eq? x 'lambda)) @@ -126,6 +129,21 @@ (define (funcall-arguments x) (drop x 2)) +(define (tail-call-proc x) + (second x)) + +(define (tail-call-arguments x) + (drop x 2)) + +(define (code-vars x) + (second x)) + +(define (code-free-vars x) + (third x)) + +(define (code-body x) + (fourth x)) + (define unique-counter (make-parameter 0)) (define (unique-number) @@ -555,12 +573,42 @@ (unless (zero? stack-start) (emit-sub (immediate stack-start) rsp)) ; restore stack pointer (emit-mov (offset (- si wordsize) rsp) rdi)) ; restore closure pointer - (begin + (begin ; eval argument (emit-expr (first args) si* env) (emit-mov rax (offset si* rsp)) (loop (cdr args) (- si* wordsize)))))) +(define (emit-tail-call proc args si env) + (let loop ((args* args) + (si* si)) + (if (null? args*) + (let ((stack-start si)) + (emit-expr proc si env) ; eval closure + (emit-sub (immediate closure-tag) rax) ; untag it to get pointer + (emit-mov rax rdi) ; store pointer in destination register + ;; Copy all of the args from their current stack locations + ;; at the top of the stack to the bottom of the stack. + ;; Function calls are expecting to find the values of their + ;; arguments starting from the bottom of the stack, so we + ;; need to set things up as if we incremented the stack + ;; pointer and made a 'call', but really we are doing a + ;; 'jmp'. I feel like we're playing a trick on the + ;; function. A very neat trick. :) + (let copy-loop ((args* args) + (si* si)) + (unless (null? args*) + (emit-mov (offset si* rsp) rax) ; copy from top of stack... + (emit-mov rax (offset (- si* si wordsize) rsp)) ; ...to the bottom + (copy-loop (cdr args*) (- si* wordsize)))) + (emit-jmp (string-append "*" (offset 0 rdi))) + (emit-mov (offset (- si wordsize) rsp) rdi)) ; restore closure pointer + (begin ; eval argument + (emit-expr (first args*) si* env) + (emit-mov rax (offset si* rsp)) + (loop (cdr args*) + (- si* wordsize)))))) + (define (emit-closure lvar vars si env) (let ((label (lookup lvar env))) (emit-lea (offset label rip) rax) ; first word of closure points to label @@ -599,7 +647,9 @@ ((closure? x) (emit-closure (second x) (cddr x) si env)) ((funcall? x) - (emit-funcall (lvar x) (arguments x) si env)) + (emit-funcall (funcall-proc x) (funcall-arguments x) si env)) + ((tail-call? x) + (emit-tail-call (tail-call-proc x) (tail-call-arguments x) si env)) (else (error "unknown expression" x)))) @@ -616,7 +666,8 @@ (let ((label (first lvar)) (code (third lvar))) (if (code? code) - (emit-code label (second code) (third code) (fourth code) env*) + (emit-code label (code-vars code) (code-free-vars code) + (code-body code) env*) (error "expected a code expression" code)))) lvars*) (emit-label "scheme_entry") @@ -730,10 +781,44 @@ (let-values (((new-x labels) (iter x))) `(labels ,labels ,new-x))) +(define (mark-tail-calls x) + (define (maybe-mark x) + (if (funcall? x) + `(tail-call ,@(cdr x)) + (mark-tail-calls x))) + (cond + ((immediate? x) x) + ((variable? x) x) + ((closure? x) x) + ((if? x) + `(if ,(mark-tail-calls (test x)) + ,(maybe-mark (consequent x)) + ,(maybe-mark (alternate x)))) + ((let? x) + `(let ,(map (lambda (binding) + (list (lhs binding) + (mark-tail-calls (rhs binding)))) + (bindings x)) + ,(maybe-mark (body x)))) + ((primcall? x) + (cons (primcall-op x) + (map mark-tail-calls (cdr x)))) + ((funcall? x) + `(funcall ,@(map mark-tail-calls (cdr x)))) + ((code? x) + `(code ,(code-vars x) ,(code-free-vars x) ,(mark-tail-calls (fourth x)))) + ((labels? x) + `(labels ,(map (lambda (binding) + (list (lhs binding) + (mark-tail-calls (rhs binding)))) + (bindings x)) + ,(mark-tail-calls (body x)))))) + (define (expand x) (parameterize ((unique-counter 0)) - (replace-lambdas-with-closures-and-funcalls - (annotate-free-variables x)))) + (mark-tail-calls + (replace-lambdas-with-closures-and-funcalls + (annotate-free-variables x))))) (define (compile-program x) (let ((x* (expand x))) @@ -837,4 +922,11 @@ (test-case '(let ((x 5)) (let ((f (lambda (y) (+ x y)))) (f 4))) - "9")) + "9") + ;; recursive tail calls + (test-case '(let ((f (lambda (x f) + (if (= x 5) + 789 + (f (add1 x) f))))) + (f 0 f)) + "789")) -- cgit v1.2.3