(use-modules (ice-9 format)
             (ice-9 match)
             (srfi srfi-1)
             (srfi srfi-9)
             (srfi srfi-9 gnu))

;; (define-matcher (NAME PATTERN ...) BODY ...)
;;
;; Define NAME as a procedure that matches its whole argument list against
;; PATTERN ... (ice-9 match syntax).  On a match BODY is evaluated; on a
;; mismatch the symbol 'unmatched is returned so that compose-matchers can
;; fall through to the next candidate.
(define-syntax-rule (define-matcher (name rules ...) body ...)
  (define name
    (match-lambda*
      ((rules ...) body ...)
      (_ 'unmatched))))

;; (compose-matchers MATCHER ...)
;;
;; Build a procedure that applies each MATCHER to its arguments in order
;; and returns the first result that is not 'unmatched.  Signals an
;; "unmatched" error when every matcher declines.
(define-syntax-rule (compose-matchers matcher ...)
  (lambda args
    (let loop ((matchers (list matcher ...)))
      (match matchers
        (() (error "unmatched" args))
        ((m . rest)
         (let ((result (apply m args)))
           (if (eq? result 'unmatched)
               (loop rest)
               result)))))))

;;;
;;; Typed expression annotation
;;;

;; A built-in type identified by a symbolic NAME.
(define-record-type <primitive-type>
  (make-primitive-type name)
  primitive-type?
  (name primitive-type-name))

(define (display-primitive-type type port)
  (format port "#<primitive-type ~a>" (primitive-type-name type)))

(set-record-type-printer! <primitive-type> display-primitive-type)

;; Each primitive type is a single shared object; the unifier compares
;; them with eqv?.
(define int-type (make-primitive-type 'int))
(define float-type (make-primitive-type 'float))
(define bool-type (make-primitive-type 'bool))

;; An unknown type, to be determined by unification.
(define-record-type <type-variable>
  (make-type-variable name)
  type-variable?
  (name type-variable-name))

(define (display-type-variable type port)
  (format port "#<type-variable ~a>" (type-variable-name type)))

(set-record-type-printer! <type-variable> display-type-variable)

;; Counter used to name fresh type variables.  annotate-exp* rebinds it
;; to 0 for every top-level expression.
(define unique-counter (make-parameter 0))

;; Return the current counter value and advance the counter by one.
(define (unique-number)
  (let ((n (unique-counter)))
    (unique-counter (+ n 1))
    n))

;; Return a symbol made of PREFIX (default T) followed by a unique number.
(define* (unique-identifier #:optional (prefix 'T))
  (string->symbol (format #f "~a~a" prefix (unique-number))))

(define (fresh-type-variable)
  (make-type-variable (unique-identifier)))

;; The type of a procedure: a list of argument types and a return type.
(define-record-type <procedure-type>
  (make-procedure-type arg-types return-type)
  procedure-type?
  (arg-types procedure-type-arg-types)
  (return-type procedure-type-return-type))

(define (display-procedure-type type port)
  (format port "#<procedure-type ~a -> ~a>"
          (procedure-type-arg-types type)
          (procedure-type-return-type type)))

(set-record-type-printer! <procedure-type> display-procedure-type)

(define (type? obj)
  (or (primitive-type? obj)
      (type-variable? obj)
      (procedure-type? obj)))

;; Type environments are association lists from variable names to types.
(define (top-level-env) '())

;; Return the type bound to VAR in ENV, or signal "unbound variable".
(define (lookup var env)
  (or (assq-ref env var)
      (error "unbound variable" var)))

;; A typed expression ("texp") is the list (t TYPE EXP), where the
;; subexpressions of EXP are themselves texps.
(define (make-texp type exp)
  `(t ,type ,exp))

(define (texp? obj)
  (match obj
    (('t _ _) #t)
    (_ #f)))

(define (texp-type texp)
  (match texp
    (('t type _) type)))

(define (texp-exp texp)
  (match texp
    (('t _ exp) exp)))

;; Texps are plain lists rather than records, so this printer cannot be
;; registered with set-record-type-printer!; call it explicitly instead.
(define (display-typed-expression texp port)
  (format port "#<texp ~a ~a>" (texp-type texp) (texp-exp texp)))

(define-matcher (annotate:bool (? boolean? b) env)
  (make-texp bool-type b))

(define-matcher (annotate:int (and (? number?) (? exact-integer? n)) env)
  (make-texp int-type n))

(define-matcher (annotate:float (and (? number?) (? inexact? n)) env)
  (make-texp float-type n))

(define-matcher (annotate:var (? symbol? var) env)
  (make-texp (lookup var env) var))

;; The type of an if form is a fresh variable; constrain:if ties it to
;; the types of both branches.
(define-matcher (annotate:if ('if predicate consequent alternate) env)
  (make-texp (fresh-type-variable)
             `(if ,(annotate-exp predicate env)
                  ,(annotate-exp consequent env)
                  ,(annotate-exp alternate env))))

;; Every parameter gets one fresh type variable, which is shared by all
;; uses of that parameter in BODY (no generalization).
(define-matcher (annotate:lambda ('lambda ((? symbol? args) ...) body) env)
  (let* ((arg-types (map (lambda (_name) (fresh-type-variable)) args))
         (env* (append (map cons args arg-types) env))
         (return-type (fresh-type-variable)))
    (make-texp (make-procedure-type arg-types return-type)
               `(lambda ,args ,(annotate-exp body env*)))))

(define-matcher (annotate:call (operator args ...) env)
  (make-texp (fresh-type-variable)
             (cons (annotate-exp operator env)
                   (map (lambda (arg) (annotate-exp arg env)) args))))

;; Annotate EXP in type environment ENV, returning a texp.  The order
;; matters: if and lambda forms must be tried before generic calls.
(define annotate-exp
  (compose-matchers annotate:bool
                    annotate:int
                    annotate:float
                    annotate:var
                    annotate:if
                    annotate:lambda
                    annotate:call))

;; Annotate a closed top-level expression, numbering type variables
;; from T0.
(define (annotate-exp* exp)
  (parameterize ((unique-counter 0))
    (annotate-exp exp (top-level-env))))

;;;
;;; Constraints
;;;

;; An equation LHS = RHS between two types.
(define-record-type <constraint>
  (constrain lhs rhs)
  constraint?
  (lhs constraint-lhs)
  (rhs constraint-rhs))

(define (display-constraint constraint port)
  (format port "#<constraint ~a = ~a>"
          (constraint-lhs constraint)
          (constraint-rhs constraint)))

(set-record-type-printer! <constraint> display-constraint)

;; Literals and variable references contribute no constraints.
(define-matcher (constrain:other (? type? _) _)
  '())

;; The predicate must be bool, and both branches must have the if's type.
(define-matcher (constrain:if (? type? type)
                              ('if (? texp? predicate)
                                   (? texp? consequent)
                                   (? texp? alternate)))
  (append (list (constrain bool-type (texp-type predicate))
                (constrain type (texp-type consequent))
                (constrain type (texp-type alternate)))
          (program-constraints predicate)
          (program-constraints consequent)
          (program-constraints alternate)))

;; The lambda's return type is the type of its body.
(define-matcher (constrain:lambda (? procedure-type? type)
                                  ('lambda ((? symbol? args) ...)
                                    (? texp? body)))
  (cons (constrain (procedure-type-return-type type) (texp-type body))
        (program-constraints body)))

;; The operator must be a procedure from the operand types to the call's type.
(define-matcher (constrain:call (? type? type)
                                ((? texp? operator) (? texp? operands) ...))
  (cons (constrain (texp-type operator)
                   (make-procedure-type (map texp-type operands) type))
        (append (program-constraints operator)
                (append-map program-constraints operands))))

(define %program-constraints
  (compose-matchers constrain:if
                    constrain:lambda
                    constrain:call
                    constrain:other))

;; Return the list of constraints implied by TEXP and its subexpressions.
(define (program-constraints texp)
  (%program-constraints (texp-type texp) (texp-exp texp)))

;;;
;;; Unification
;;;

(define-matcher (occurs:default _ _)
  #f)

(define-matcher (occurs:variable (? type-variable? a) (? type-variable? b))
  (eq? a b))

(define-matcher (occurs:procedure (? type-variable? v) (? procedure-type? p))
  (or (any (lambda (arg-type) (occurs-in? v arg-type))
           (procedure-type-arg-types p))
      (occurs-in? v (procedure-type-return-type p))))

;; Does type variable A occur in type B?
(define occurs-in?
  (compose-matchers occurs:variable occurs:procedure occurs:default))

;; Replace TERM by its binding when TERM itself is a key of DICT.  Only
;; the top level of TERM is examined.
(define (substitute-term term dict)
  (match dict
    (() term)
    (((from . to) . rest)
     (if (eq? term from)
         to
         (substitute-term term rest)))))

;; Replace every key or value in DICT that is eq? to VAR with TERM
;; (top level only).
(define (substitute-dict var term dict)
  (map (match-lambda
         ((a . b)
          (cons (if (eq? a var) term a)
                (if (eq? b var) term b))))
       dict))

;; Extend DICT with VAR := TERM, or return #f when the occurs check fails.
(define (substitute var term dict)
  (let ((term* (substitute-term term dict)))
    (and (not (occurs-in? var term*))
         (alist-cons var term* (substitute-dict var term* dict)))))

;; Return a unifier for the lists VAR-FIRST (headed by a type variable)
;; and TERMS.
(define (maybe-substitute var-first terms)
  (lambda (dict succeed fail)
    (match var-first
      ((var . rest-vars)
       (match terms
         ((term . rest-terms)
          (cond
           ;; Tautology: matching 2 vars that are the same.
           ((and (type-variable? term) (eq? var term))
            (succeed dict fail rest-vars rest-terms))
           ;; Variable has been bound to some other value, recursively
           ;; follow it and unify.
           ((assq-ref dict var)
            => (lambda (type)
                 ((unify:dispatch (cons type rest-vars) terms)
                  dict succeed fail)))
           ;; Substitute variable for value.
           (else
            (let ((dict* (substitute var term dict)))
              (if dict*
                  (succeed dict* fail rest-vars rest-terms)
                  (fail))))))
         ;; The other side has no element to pair with VAR (e.g. a
         ;; procedure applied to the wrong number of arguments): this is
         ;; a unification failure, not a pattern-match error.
         (_ (fail)))))))

;; Primitive types, and anything else that is not a type variable,
;; procedure type or list, are compared with eqv?.
(define (constant? obj)
  (and (not (type-variable? obj))
       (not (procedure-type? obj))
       (not (list? obj))))

;; A unifier is a procedure (dict succeed fail).  It calls
;; (succeed dict* fail* rest-a rest-b) after consuming a prefix of both
;; sides, or (fail) with no arguments.

(define-matcher (unify:fail a b)
  (lambda (dict succeed fail)
    (fail)))

(define-matcher (unify:constants ((? constant? a) . rest-a)
                                 ((? constant? b) . rest-b))
  (lambda (dict succeed fail)
    (if (eqv? a b)
        (succeed dict fail rest-a rest-b)
        (fail))))

;; Unify the heads A and B element-wise, then continue with the tails.
(define-matcher (unify:lists (a . rest-a) (b . rest-b))
  (lambda (dict succeed fail)
    ((unify:dispatch a b)
     dict
     (lambda (dict* fail* _null-a _null-b)
       (succeed dict* fail* rest-a rest-b))
     fail)))

(define-matcher (unify:variable-left (and ((? type-variable? _) . _) a) b)
  (maybe-substitute a b))

(define-matcher (unify:variable-right a (and ((? type-variable? _) . _) b))
  (maybe-substitute b a))

;; Two procedure types unify when their return types and argument types
;; (in that order, as one list) unify.
(define-matcher (unify:procedures ((? procedure-type? a) . rest-a)
                                  ((? procedure-type? b) . rest-b))
  (lambda (dict succeed fail)
    ((unify:dispatch (cons (procedure-type-return-type a)
                           (procedure-type-arg-types a))
                     (cons (procedure-type-return-type b)
                           (procedure-type-arg-types b)))
     dict
     (lambda (dict* fail* _null-a _null-b)
       (succeed dict* fail* rest-a rest-b))
     fail)))

(define %unify:dispatch
  (compose-matchers unify:constants
                    unify:variable-left
                    unify:variable-right
                    unify:procedures
                    unify:lists
                    unify:fail))

;; Return a unifier for the whole lists A and B: dispatch on their heads
;; and keep going until both lists are exhausted.
(define (unify:dispatch a b)
  (lambda (dict succeed fail)
    (if (and (null? a) (null? b))
        (succeed dict fail a b)
        ((%unify:dispatch a b)
         dict
         (lambda (dict* fail* rest-a rest-b)
           ((unify:dispatch rest-a rest-b) dict* succeed fail*))
         fail))))

;; Unify A and B starting from DICT.  Returns (succeed dict*) when both
;; sides are fully consumed, otherwise #f.
(define (%unify a b dict succeed)
  ((unify:dispatch (list a) (list b))
   dict
   (lambda (dict fail rest-a rest-b)
     (or (and (null? rest-a) (null? rest-b) (succeed dict))
         (fail)))
   (lambda () #f)))

;; Return the substitution alist unifying A with B, or #f.
(define (unify a b)
  (%unify a b '() identity))

(define (unify-constraints constraints)
  (unify (map constraint-lhs constraints)
         (map constraint-rhs constraints)))

;;;
;;; Type Resolution
;;;

(define (primitive? x)
  (or (number? x) (boolean? x) (symbol? x)))

(define-matcher (resolve:primitive (? primitive? x) dict)
  x)

(define-matcher (resolve:primitive-type (? primitive-type? type) dict)
  type)

;; Resolve VAR through DICT.  The binding is itself resolved, so chains
;; of variables and procedure types that mention other bound variables
;; come out concrete.  VAR is removed from DICT while its binding is
;; resolved, so a cyclic substitution ends in "cannot determine type"
;; instead of looping forever.
(define-matcher (resolve:type-variable (? type-variable? var) dict)
  (let ((type (assq-ref dict var)))
    (if type
        (resolve-types type (alist-delete var dict eq?))
        (error "cannot determine type" var))))

(define-matcher (resolve:procedure-type (? procedure-type? type) dict)
  (make-procedure-type
   (resolve-types (procedure-type-arg-types type) dict)
   (resolve-types (procedure-type-return-type type) dict)))

(define-matcher (resolve:list (exps ...) dict)
  (map (lambda (texp) (resolve-types texp dict)) exps))

;; Replace every type variable in a texp (or type, or list of them) with
;; its concrete type from DICT.
(define resolve-types
  (compose-matchers resolve:primitive
                    resolve:primitive-type
                    resolve:type-variable
                    resolve:procedure-type
                    resolve:list))

;; Annotate EXP, solve its constraints and return the fully-typed texp.
;; Signals "type mismatch" when the constraints cannot be unified.
(define (infer-types exp)
  (let* ((texp (annotate-exp* exp))
         (constraints (program-constraints texp))
         (substitutions (unify-constraints constraints)))
    (unless substitutions
      (error "type mismatch" texp))
    (resolve-types texp substitutions)))

(infer-types #t)
(infer-types 6)
(infer-types 6.5)
(infer-types '(if #t 1 2))
(infer-types '((lambda (x) x) 1))

;; Lambda-bound variables are not generalized, so F cannot be used at both
;; bool and int: this example is rejected with a "type mismatch" error.
;; Catch it so that loading the file does not abort here.
(catch 'misc-error
  (lambda ()
    (infer-types '((lambda (f) (if (f #t) (f 1) (f 2))) (lambda (x) x))))
  (lambda (key . args) #f))