(use-modules (ice-9 format)
             (ice-9 match)
             (srfi srfi-1)
             (srfi srfi-9)
             (srfi srfi-9 gnu))

;; Define NAME as a procedure that pattern-matches its whole argument
;; list against (RULES ...).  On a match BODY is evaluated; otherwise the
;; symbol 'unmatched is returned so that COMPOSE-MATCHERS can move on to
;; the next matcher.
(define-syntax-rule (define-matcher (name rules ...) body ...)
  (define name
    (match-lambda*
      ((rules ...) body ...)
      (_ 'unmatched))))

;; Build a procedure that tries each MATCHER in order on its arguments and
;; returns the first result that is not 'unmatched.  Signals an
;; "unmatched" error when no matcher accepts the arguments.
(define-syntax-rule (compose-matchers matcher ...)
  (lambda args
    (let loop ((matchers (list matcher ...)))
      (match matchers
        (() (error "unmatched" args))
        ((m . rest)
         (let ((result (apply m args)))
           (if (eq? result 'unmatched)
               (loop rest)
               result)))))))

;;;
;;; Typed expression annotation
;;;

;; A named built-in type (int, float, bool below).
(define-record-type <primitive-type>
  (make-primitive-type name)
  primitive-type?
  (name primitive-type-name))

(define (display-primitive-type type port)
  (format port "#<primitive-type ~a>" (primitive-type-name type)))

(set-record-type-printer! <primitive-type> display-primitive-type)

(define int-type (make-primitive-type 'int))
(define float-type (make-primitive-type 'float))
(define bool-type (make-primitive-type 'bool))

;; A placeholder type, to be solved by unification.
(define-record-type <type-variable>
  (make-type-variable name)
  type-variable?
  (name type-variable-name))

(define (display-type-variable type port)
  (format port "#<type-variable ~a>" (type-variable-name type)))

(set-record-type-printer! <type-variable> display-type-variable)

;; Source of unique numbers for generated names.  ANNOTATE-EXP* rebinds
;; it to 0 for each annotation run.
(define unique-counter (make-parameter 0))

;; Return the current counter value and advance the counter.
(define (unique-number)
  (let ((n (unique-counter)))
    (unique-counter (+ n 1))
    n))

;; Return a fresh symbol such as T0, T1, ... (PREFIX defaults to T).
(define* (unique-identifier #:optional (prefix 'T))
  (string->symbol (format #f "~a~a" prefix (unique-number))))

(define (fresh-type-variable)
  (make-type-variable (unique-identifier)))

;; A procedure type: a list of parameter types and a list of return
;; types (one entry per returned value).
(define-record-type <procedure-type>
  (make-procedure-type parameter-types return-types)
  procedure-type?
  (parameter-types procedure-type-parameter-types)
  (return-types procedure-type-return-types))

(define (display-procedure-type type port)
  (format port "#<procedure-type ~a -> ~a>"
          (procedure-type-parameter-types type)
          (procedure-type-return-types type)))

(set-record-type-printer! <procedure-type> display-procedure-type)

(define (type? obj)
  (or (primitive-type? obj)
      (type-variable? obj)
      (procedure-type? obj)))

;; Environments are alists from variable names to types.
(define (top-level-env) '())

(define (lookup var env)
  (or (assq-ref env var)
      (error "unbound variable" var)))

;; A typed expression ("texp") is the plain list (t TYPES EXP), where
;; TYPES is a list of types and EXP the (annotated) expression.
(define (make-texp types exp)
  `(t ,types ,exp))

(define (texp? obj)
  (match obj
    (('t _ _) #t)
    (_ #f)))

(define (texp-types texp)
  (match texp
    (('t types _) types)))

(define (texp-exp texp)
  (match texp
    (('t _ exp) exp)))

;; Texps are lists, not records, so this cannot be registered with
;; SET-RECORD-TYPE-PRINTER!; call it directly for a compact rendering.
(define (display-typed-expression texp port)
  (format port "#<texp ~a ~a>" (texp-types texp) (texp-exp texp)))

(define-matcher (annotate:bool (? boolean? b) env)
  (make-texp (list bool-type) b))

(define-matcher (annotate:int (and (? number?) (? exact-integer? n)) env)
  (make-texp (list int-type) n))

(define-matcher (annotate:float (and (? number?) (? inexact? n)) env)
  (make-texp (list float-type) n))

(define-matcher (annotate:var (? symbol? var) env)
  (make-texp (list (lookup var env)) var))

;; The type of an `if' is a fresh variable, tied to both branches by
;; CONSTRAIN:IF.
(define-matcher (annotate:if ('if predicate consequent alternate) env)
  (make-texp (list (fresh-type-variable))
             `(if ,(annotate-exp predicate env)
                  ,(annotate-exp consequent env)
                  ,(annotate-exp alternate env))))

;; Each parameter gets a fresh type variable.  Each value the body
;; produces gets a fresh return type variable.
(define-matcher (annotate:lambda ('lambda ((? symbol? args) ...) body) env)
  (define parameter-types
    (map (lambda (_name) (fresh-type-variable)) args))
  (define env*
    (append (map cons args parameter-types) env))
  (define body* (annotate-exp body env*))
  (define return-types
    (map (lambda (_type) (fresh-type-variable)) (texp-types body*)))
  (make-texp (list (make-procedure-type parameter-types return-types))
             `(lambda ,args ,body*)))

(define-matcher (annotate:call (operator args ...) env)
  (define operator* (annotate-exp operator env))
  (make-texp (map (lambda (_type) (fresh-type-variable))
                  (texp-types operator*))
             (cons operator*
                   (map (lambda (arg) (annotate-exp arg env)) args))))

(define annotate-exp
  (compose-matchers annotate:bool
                    annotate:int
                    annotate:float
                    annotate:var
                    annotate:if
                    annotate:lambda
                    annotate:call))

;; Annotate EXP in the top-level environment.  Type variable numbering
;; restarts at 0 on each call.
(define (annotate-exp* exp)
  (parameterize ((unique-counter 0))
    (annotate-exp exp (top-level-env))))

;;;
;;; Constraints
;;;

;; An equation LHS = RHS between two lists of types.
(define-record-type <constraint>
  (constrain lhs rhs)
  constraint?
  (lhs constraint-lhs)
  (rhs constraint-rhs))

(define (display-constraint constraint port)
  (format port "#<constraint ~a = ~a>"
          (constraint-lhs constraint)
          (constraint-rhs constraint)))

(set-record-type-printer! <constraint> display-constraint)

;; Literals and variables contribute no constraints.
(define-matcher (constrain:other ((? type? _) ...) _)
  '())

(define-matcher (constrain:if ((? type? types) ...)
                              ('if (? texp? predicate)
                                   (? texp? consequent)
                                   (? texp? alternate)))
  (append (list (constrain (list bool-type) (texp-types predicate))
                (constrain types (texp-types consequent))
                (constrain types (texp-types alternate)))
          (program-constraints predicate)
          (program-constraints consequent)
          (program-constraints alternate)))

(define-matcher (constrain:lambda ((? procedure-type? type))
                                  ('lambda ((? symbol? args) ...) (? texp? body)))
  (cons (constrain (procedure-type-return-types type) (texp-types body))
        (program-constraints body)))

(define-matcher (constrain:call ((? type? types) ...)
                                ((? texp? operator) (? texp? operands) ...))
  ;; The operator must be a procedure taking the operands' types and
  ;; returning the call's types.  Operand type lists are concatenated
  ;; so they line up with a lambda's flat parameter-type list.
  (cons (constrain (texp-types operator)
                   (list (make-procedure-type (append-map texp-types operands)
                                              types)))
        (append (program-constraints operator)
                (append-map program-constraints operands))))

(define %program-constraints
  (compose-matchers constrain:if
                    constrain:lambda
                    constrain:call
                    constrain:other))

;; Return the list of constraints implied by the annotated expression TEXP.
(define (program-constraints texp)
  (%program-constraints (texp-types texp) (texp-exp texp)))

;;;
;;; Occurs check
;;;

;; (occurs-in? VAR TERM) is true when type variable VAR appears anywhere
;; in TERM, a type or a list of types.

(define-matcher (occurs:default _ _) #f)

(define-matcher (occurs:list (? type-variable? v) (terms ...))
  (any (lambda (term) (occurs-in? v term)) terms))

(define-matcher (occurs:variable (? type-variable? a) (? type-variable? b))
  (eq? a b))

(define-matcher (occurs:procedure (? type-variable? v) (? procedure-type? p))
  (or (occurs-in? v (procedure-type-parameter-types p))
      (occurs-in? v (procedure-type-return-types p))))

(define occurs-in?
  (compose-matchers occurs:list
                    occurs:variable
                    occurs:procedure
                    occurs:default))

;;;
;;; Unification
;;;

;; Return the binding of TERM in DICT if TERM is itself a key, else TERM.
;; Only the top level of TERM is examined.
(define (substitute-term term dict)
  (match dict
    (() term)
    (((from . to) . rest)
     (if (eq? term from)
         to
         (substitute-term term rest)))))

;; Replace each key or value in DICT that is eq? to VAR with TERM.
;; NOTE(review): occurrences nested inside procedure types are not
;; rewritten.
(define (substitute-dict var term dict)
  (map (match-lambda
         ((a . b)
          (cons (if (eq? a var) term a)
                (if (eq? b var) term b))))
       dict))

;; Extend DICT with VAR := OTHER.  Returns #f if VAR occurs in OTHER
;; after top-level substitution (the occurs check).
(define (substitute var other dict)
  (let ((other* (substitute-term other dict)))
    (and (not (occurs-in? var other*))
         (alist-cons var other* (substitute-dict var other* dict)))))

(define %unify-prompt-tag (make-prompt-tag 'unify))

;; Abort the unification in progress.  UNIFY-CONSTRAINTS returns #f in
;; that case.
(define (unify-fail)
  (pk 'unify-fail)
  (abort-to-prompt %unify-prompt-tag))

(define (maybe-substitute var other dict)
  (cond
   ;; Tautology: matching 2 vars that are the same.
   ((and (type-variable? other) (eq? var other))
    dict)
   ;; Variable has been bound to some other value, recursively follow
   ;; it and unify.
   ((assq-ref dict var)
    => (lambda (type) (unify type other dict)))
   ;; Substitute variable for value.
   (else
    (or (substitute var other dict)
        (unify-fail)))))

(define (constant? obj)
  (and (not (type-variable? obj))
       (not (procedure-type? obj))
       (not (list? obj))))

(define-matcher (unify:constants a b dict)
  (pk 'unify:constants a b)
  (if (eqv? a b)
      dict
      (unify-fail)))

(define-matcher (unify:lists (a rest-a ...) (b rest-b ...) dict)
  (pk 'unify:lists (cons a rest-a) (cons b rest-b))
  (let ((dict* (unify a b dict)))
    (unify rest-a rest-b dict*)))

(define-matcher (unify:variable-left (? type-variable? a) b dict)
  (pk 'unify:variable-left a b)
  (maybe-substitute a b dict))

(define-matcher (unify:variable-right a (? type-variable? b) dict)
  (pk 'unify:variable-right a b)
  (maybe-substitute b a dict))

;; Unify parameter lists, then return lists.  A LET is used instead of
;; an internal DEFINE after an expression, which Guile < 3.0 rejects.
(define-matcher (unify:procedures (? procedure-type? a) (? procedure-type? b) dict)
  (pk 'unify:procedures a b)
  (let ((dict* (unify (procedure-type-parameter-types a)
                      (procedure-type-parameter-types b)
                      dict)))
    (unify (procedure-type-return-types a)
           (procedure-type-return-types b)
           dict*)))

(define unify
  (compose-matchers unify:variable-left
                    unify:variable-right
                    unify:procedures
                    unify:lists
                    unify:constants))

;; Solve CONSTRAINTS.  Returns an alist from type variables to types, or
;; #f when the constraints cannot be satisfied.
(define (unify-constraints constraints)
  (call-with-prompt %unify-prompt-tag
    (lambda ()
      (unify (map constraint-lhs constraints)
             (map constraint-rhs constraints)
             '()))
    (lambda (_k) #f)))

;;;
;;; Type Resolution
;;;

;; Replace every type variable in a texp by its binding in DICT.
;; NOTE(review): a variable bound to a procedure type is returned as-is;
;; variables nested inside that procedure type are not resolved further.

(define-matcher (resolve:primitive x dict) x)

(define-matcher (resolve:primitive-type (? primitive-type? type) dict) type)

(define-matcher (resolve:type-variable (? type-variable? var) dict)
  (let ((type (assq-ref dict var)))
    (if (or (not type) (type-variable? type))
        (error "cannot determine type" var)
        type)))

(define-matcher (resolve:procedure-type (? procedure-type? type) dict)
  (make-procedure-type (resolve-types (procedure-type-parameter-types type) dict)
                       (resolve-types (procedure-type-return-types type) dict)))

(define-matcher (resolve:list (exps ...) dict)
  (map (lambda (texp) (resolve-types texp dict)) exps))

(define resolve-types
  (compose-matchers resolve:primitive-type
                    resolve:type-variable
                    resolve:procedure-type
                    resolve:list
                    resolve:primitive))

;; Return EXP as a fully typed texp, or #f if type inference fails
;; through unification.  Signals an error if some type remains
;; undetermined.
(define (infer-types exp)
  (let* ((texp (annotate-exp* exp))
         (constraints (pk 'constraints (program-constraints texp)))
         (substitutions (unify-constraints constraints)))
    (and substitutions
         (resolve-types texp substitutions))))

(infer-types #t)
(infer-types 6)
(infer-types 6.5)
(infer-types '(if #t 1 2))
(infer-types '((lambda (x) x) 1))
;;(infer-types '((lambda (f) (if (f #t) (f 1) (f 2))) (lambda (x) x)))