;;; Prototype type inference for a small Scheme-like language.
;;;
;;; Features:
;;; - multiple return values
;;; - for-all (polymorphic) types
;;; - intersection types (operator overloading)
;;;
;;; Pipeline (see `infer-types' at the end): annotate every expression
;;; with the list of types it produces, collect equality constraints
;;; between type lists, unify the constraints into a substitution, and
;;; finally resolve the annotated program against that substitution.

(use-modules (ice-9 format)
             (ice-9 match)
             (srfi srfi-1)
             (srfi srfi-9)
             (srfi srfi-9 gnu))

;; Print each element of LST on its own line to the current output port.
(define (print-list lst)
  (for-each (lambda (x) (format #t "~a\n" x)) lst))

;; Return the elements of A that do not occur (by `eq?') in B.
;; This is a one-sided set difference: elements of B that are not in A
;; are NOT part of the result.
(define (difference a b)
  (remove (lambda (x) (memq x b)) a))

;; (define-matcher (NAME PATTERN ...) BODY ...)
;;
;; Define NAME as a procedure that matches its arguments positionally
;; against PATTERN ... using `match-lambda*'.  On a match BODY is
;; evaluated; otherwise the procedure returns the symbol `unmatched' so
;; that `compose-matchers' can try the next matcher.
(define-syntax-rule (define-matcher (name rules ...) body ...)
  (define name
    (match-lambda*
      ((rules ...) body ...)
      (_ 'unmatched))))

;; (compose-matchers MATCHER ...)
;;
;; Return a procedure that applies each MATCHER, in order, to its
;; arguments and returns the first result that is not `unmatched'.
;; Signals an error if every matcher declines.  The MATCHER expressions
;; are evaluated on every call, so they may name procedures that are
;; defined later in the file.
(define-syntax-rule (compose-matchers matcher ...)
  (lambda args
    (let loop ((matchers (list matcher ...)))
      (match matchers
        (() (error "unmatched" args))
        ((m . rest)
         (let ((result (apply m args)))
           (if (eq? result 'unmatched)
               (loop rest)
               result)))))))

;;;
;;; Types and environments
;;;

;; A named built-in type such as int, float or bool.
(define-record-type <primitive-type>
  (make-primitive-type name)
  primitive-type?
  (name primitive-type-name))

(define (display-primitive-type type port)
  (format port "#<primitive-type ~a>" (primitive-type-name type)))

(set-record-type-printer! <primitive-type> display-primitive-type)

(define int-type (make-primitive-type 'int))
(define float-type (make-primitive-type 'float))
(define bool-type (make-primitive-type 'bool))

;; A type variable.  Type variables are compared by identity (`eq?')
;; everywhere in this file, never by name.
(define-record-type <type-variable>
  (make-type-variable name)
  type-variable?
  (name type-variable-name))

(define (display-type-variable type port)
  (format port "#<type-variable ~a>" (type-variable-name type)))

(set-record-type-printer! <type-variable> display-type-variable)

;; Counter for generating unique identifiers; `infer-types' rebinds it
;; to 0 for each inference run.
(define unique-counter (make-parameter 0))

;; Return the current counter value and advance the counter by one.
(define (unique-number)
  (let ((n (unique-counter)))
    (unique-counter (+ n 1))
    n))

;; Return a fresh symbol PREFIX<n>, e.g. T0, T1, ... (PREFIX defaults
;; to T).
(define* (unique-identifier #:optional (prefix 'T))
  (string->symbol (format #f "~a~a" prefix (unique-number))))

(define (fresh-type-variable)
  (make-type-variable (unique-identifier)))

;; The type of a procedure.  Both fields are lists of types; the return
;; field is a list because a procedure may return multiple values.
(define-record-type <procedure-type>
  (make-procedure-type parameter-types return-types)
  procedure-type?
  (parameter-types procedure-type-parameter-types)
  (return-types procedure-type-return-types))

(define (display-procedure-type type port)
  (format port "#<procedure-type ~a -> ~a>"
          (procedure-type-parameter-types type)
          (procedure-type-return-types type)))

(set-record-type-printer! <procedure-type> display-procedure-type)

;; A polymorphic type: TYPE universally quantified over the type
;; variables listed in VARIABLES.
(define-record-type <for-all-type>
  (make-for-all-type variables type)
  for-all-type?
  (variables for-all-type-variables)
  (type for-all-type-type))

(define (display-for-all-type type port)
  (format port "#<for-all-type ~a ~a>"
          (for-all-type-variables type)
          (for-all-type-type type)))

(set-record-type-printer! <for-all-type> display-for-all-type)

;; An overloaded type: a value that may be used as any one of TYPES.
(define-record-type <intersection-type>
  (make-intersection-type types)
  intersection-type?
  (types intersection-type-types))

(define (display-intersection-type type port)
  (format port "#<intersection-type ~a>" (intersection-type-types type)))

(set-record-type-printer! <intersection-type> display-intersection-type)

;; True for the monomorphic type representations.  For-all types are
;; not included; `lookup' handles them separately by instantiating.
(define (type? obj)
  (or (primitive-type? obj)
      (type-variable? obj)
      (procedure-type? obj)
      (intersection-type? obj)))

;; The initial environment: an alist from variable names to types.
;; `+' is overloaded for ints and for floats.
(define (top-level-env)
  `((+ . ,(make-intersection-type
           (list (make-procedure-type (list int-type int-type)
                                      (list int-type))
                 (make-procedure-type (list float-type float-type)
                                      (list float-type)))))))

;; Return the type of VAR in ENV.  For-all types are instantiated with
;; fresh type variables.  Signals an error if VAR is unbound.
(define (lookup var env)
  (let ((type (assq-ref env var)))
    (cond
     ((type? type) type)
     ((for-all-type? type) (instantiate type))
     (else (error "unbound variable:" var)))))

;; Does type variable A occur anywhere in B?  B may be a type or a list
;; of types.
(define (occurs-in? a b)
  (cond
   ((and (type-variable? a) (type-variable? b))
    (eq? a b))
   ((and (type-variable? a) (procedure-type? b))
    (or (occurs-in? a (procedure-type-parameter-types b))
        (occurs-in? a (procedure-type-return-types b))))
   ((and (type-variable? a) (intersection-type? b))
    (occurs-in? a (intersection-type-types b)))
   ((and (type? a) (list? b))
    (any (lambda (b*) (occurs-in? a b*)) b))
   (else #f)))

;; Return TYPE with every occurrence of the type variable FROM replaced
;; by TO.  Anything that is not a recognized type is returned as-is.
(define (apply-substitution-to-type type from to)
  (define (substitute-in t)
    (apply-substitution-to-type t from to))
  (cond
   ((primitive-type? type) type)
   ((type-variable? type)
    (if (eq? type from) to type))
   ((procedure-type? type)
    (make-procedure-type (map substitute-in
                              (procedure-type-parameter-types type))
                         (map substitute-in
                              (procedure-type-return-types type))))
   ((intersection-type? type)
    (make-intersection-type (map substitute-in
                                 (intersection-type-types type))))
   ((for-all-type? type)
    ;; A quantifier that binds FROM shadows it inside the body.
    (if (memq from (for-all-type-variables type))
        type
        (make-for-all-type (for-all-type-variables type)
                           (substitute-in (for-all-type-type type)))))
   (else type)))

;; Apply each (FROM . TO) pair of SUBS to TYPE, in list order.
(define (apply-substitutions-to-type type subs)
  (fold (lambda (sub acc)
          (apply-substitution-to-type acc (car sub) (cdr sub)))
        type
        subs))

;; Apply the binding FROM -> TO to every entry of the substitution
;; alist ENV: a key equal to FROM is replaced by TO, and FROM is
;; substituted throughout each bound type.
(define (apply-substitution-to-env from to env)
  (map (match-lambda
         ((a . b)
          (cons (if (eq? a from) to a)
                (apply-substitution-to-type b from to))))
       env))

;; Extend the substitution ENV (an alist of type variable -> type) with
;; FROM -> TO.  TO is first rewritten with ENV, and the new binding is
;; pushed into the existing ones.  Returns:
;;   - ENV unchanged if the rewritten TO is FROM itself (a tautology),
;;   - #f if FROM occurs inside the rewritten TO (a cyclic type),
;;   - otherwise the extended alist.
(define (substitute from to env)
  (let ((to* (apply-substitutions-to-type to env)))
    (cond
     ((eq? from to*) env)
     ((occurs-in? from to*) #f)
     (else
      (let ((env* (apply-substitution-to-env from to* env)))
        (pk 'add-to-env from to*)
        (alist-cons from to* env*))))))

;; Return the free type variables of TYPE as a list without duplicates.
(define (free-variables-in-type type)
  (define (free-in-all types)
    (delete-duplicates (append-map free-variables-in-type types) eq?))
  (cond
   ((primitive-type? type) '())
   ((type-variable? type) (list type))
   ((procedure-type? type)
    (free-in-all (append (procedure-type-parameter-types type)
                         (procedure-type-return-types type))))
   ((intersection-type? type)
    (free-in-all (intersection-type-types type)))
   ((for-all-type? type)
    (free-variables-in-for-all type))
   (else '())))

;; Free type variables of a for-all type: those of its body minus the
;; quantified ones.
(define (free-variables-in-for-all for-all)
  (difference (free-variables-in-type (for-all-type-type for-all))
              (for-all-type-variables for-all)))

;; Free type variables of every type bound in the environment alist ENV.
(define (free-variables-in-env env)
  (delete-duplicates
   (append-map (match-lambda
                 ((_ . type) (free-variables-in-type type)))
               env)
   eq?))

;; Replace the quantified variables of FOR-ALL with fresh type
;; variables and return the resulting type.
(define (instantiate for-all)
  (define subs
    (map (lambda (var) (cons var (fresh-type-variable)))
         (for-all-type-variables for-all)))
  (pk 'instantiate for-all)
  (apply-substitutions-to-type (for-all-type-type for-all) subs))

;; Quantify TYPE over the type variables free in TYPE but not free in
;; ENV.  Returns TYPE itself when there is nothing to quantify.
;;
;; NOTE(review): this is called during annotation, before constraints
;; are solved, so a variable that unification would later tie to ENV
;; can still be quantified here.
(define (generalize type env)
  (define quantifiers
    (difference (free-variables-in-type type)
                (free-variables-in-env env)))
  (pk 'quantifiers quantifiers)
  (if (null? quantifiers)
      type
      (make-for-all-type quantifiers type)))

;; Convert a list of constraints into an alist of (lhs . rhs) pairs.
(define (constraints->env constraints)
  (map (lambda (c)
         (cons (constraint-lhs c) (constraint-rhs c)))
       constraints))

;;;
;;; Typed expression annotation
;;;
;;; A typed expression ("texp") is the list (t TYPES EXP): EXP paired
;;; with the list of types of the values it produces.

(define (make-texp types exp)
  `(t ,types ,exp))

(define (texp? obj)
  (match obj
    (('t _ _) #t)
    (_ #f)))

(define (texp-types texp)
  (match texp
    (('t types _) types)))

(define (texp-exp texp)
  (match texp
    (('t _ exp) exp)))

;; Return the only type of TEXP; signals an error unless TEXP has
;; exactly one type.
(define (single-type texp)
  (match (texp-types texp)
    ((type) type)
    (_ (error "expected only 1 type" texp))))

(define-matcher (annotate:bool (? boolean? b) env)
  (make-texp (list bool-type) b))

(define-matcher (annotate:int (and (? number?) (? exact-integer? n)) env)
  (make-texp (list int-type) n))

(define-matcher (annotate:float (and (? number?) (? inexact? n)) env)
  (make-texp (list float-type) n))

(define-matcher (annotate:var (? symbol? var) env)
  (make-texp (list (lookup var env)) var))

;; The `if' expression gets a fresh type variable; `constrain:if' ties
;; it to the types of both branches.
(define-matcher (annotate:if ('if predicate consequent alternate) env)
  (make-texp (list (fresh-type-variable))
             `(if ,(annotate-exp predicate env)
                  ,(annotate-exp consequent env)
                  ,(annotate-exp alternate env))))

;; Let-bound values are generalized before the body is annotated, so a
;; binding with a for-all type is instantiated afresh at each use.
;;
;; NOTE(review): when a value's type is quantified, only a lambda has a
;; matching constraint rule (`constrain:for-all-lambda'); for any other
;; expression `program-constraints' signals "unmatched".
(define-matcher (annotate:let ('let (((? symbol? vars) vals) ...) body) env)
  (define vals* (map (lambda (val) (annotate-exp val env)) vals))
  (define val-types
    (map (lambda (val) (generalize (single-type val) env)) vals*))
  ;; Re-annotate each value with its (possibly generalized) type.
  (define vals**
    (map (lambda (val type) (make-texp (list type) (texp-exp val)))
         vals* val-types))
  (define env* (append (map cons vars val-types) env))
  (define body* (annotate-exp body env*))
  (make-texp (texp-types body*)
             `(let ,(map list vars vals**) ,body*)))

;; A `begin' has the types of its last expression.
(define-matcher (annotate:begin ('begin exps ... last) env)
  (define last* (annotate-exp last env))
  (make-texp (texp-types last*)
             `(begin ,@(map (lambda (exp) (annotate-exp exp env)) exps)
                     ,last*)))

;; Each parameter gets a fresh type variable; the return types are the
;; types of the body.
(define-matcher (annotate:lambda ('lambda ((? symbol? args) ...) body) env)
  (define parameter-types (map (lambda (_name) (fresh-type-variable)) args))
  (define env* (append (map cons args parameter-types) env))
  (define body* (annotate-exp body env*))
  (define return-types (texp-types body*))
  (make-texp (list (make-procedure-type parameter-types return-types))
             `(lambda ,args ,body*)))

;; One fresh result type variable is created per type of the operator.
;; An operator normally has a single type, so the call gets a single
;; result type; `constrain:call' relates it to the operator.
(define-matcher (annotate:call (operator args ...) env)
  (define operator* (annotate-exp operator env))
  (define args* (map (lambda (arg) (annotate-exp arg env)) args))
  (make-texp (map (lambda (_type) (fresh-type-variable))
                  (texp-types operator*))
             (cons operator* args*)))

;; `values' produces one type per argument; each argument must have
;; exactly one type.
(define-matcher (annotate:values ('values vals ...) env)
  (define vals* (map (lambda (val) (annotate-exp val env)) vals))
  (define val-types (map single-type vals*))
  (make-texp val-types `(values ,@vals*)))

;; NOTE(review): the texp is given the consumer's own type(s), not the
;; consumer's return types.  `constrain:call-with-values' relies on this
;; to find the consumer's parameter types, but it means the resolved
;; type of the whole expression is the consumer's procedure type.  If
;; the consumer's type is not a procedure type (e.g. `+'), no rule but
;; `constrain:other' matches and no constraints are generated for the
;; producer or consumer.  Confirm intent.
(define-matcher (annotate:call-with-values ('call-with-values producer consumer) env)
  (define producer* (annotate-exp producer env))
  (define consumer* (annotate-exp consumer env))
  (make-texp (texp-types consumer*)
             `(call-with-values ,producer* ,consumer*)))

;; Annotate an expression in an environment.  `annotate:call' must come
;; last because its pattern matches any non-empty list.
(define annotate-exp
  (compose-matchers annotate:bool
                    annotate:int
                    annotate:float
                    annotate:var
                    annotate:if
                    annotate:let
                    annotate:begin
                    annotate:lambda
                    annotate:values
                    annotate:call-with-values
                    annotate:call))

;; Annotate EXP in the top-level environment.
(define (annotate-exp* exp)
  (annotate-exp exp (top-level-env)))

;;;
;;; Constraints
;;;
;;; A constraint states that two lists of types must unify.

(define-record-type <constraint>
  (constrain lhs rhs)
  constraint?
  (lhs constraint-lhs)
  (rhs constraint-rhs))

(define (display-constraint constraint port)
  (format port "#<constraint ~a = ~a>"
          (constraint-lhs constraint)
          (constraint-rhs constraint)))

(set-record-type-printer! <constraint> display-constraint)

;; Leaves (constants, variables) and anything else with monomorphic
;; types produce no constraints.
(define-matcher (constrain:other ((? type? _) ...) _)
  '())

;; The predicate must be a bool; both branches must match the type of
;; the `if' expression.
(define-matcher (constrain:if ((? type? types) ...)
                              ('if (? texp? predicate)
                                   (? texp? consequent)
                                   (? texp? alternate)))
  (append (list (constrain (list bool-type) (texp-types predicate))
                (constrain types (texp-types consequent))
                (constrain types (texp-types alternate)))
          (program-constraints predicate)
          (program-constraints consequent)
          (program-constraints alternate)))

(define-matcher (constrain:let ((? type? types) ...)
                               ('let (((? symbol? vars) (? texp? vals)) ...)
                                 (? texp? body)))
  (append (program-constraints body)
          (append-map program-constraints vals)))

(define-matcher (constrain:begin ((? type? types) ...)
                                 ('begin (? texp? texps) ... (? texp? last)))
  (append (program-constraints last)
          (append-map program-constraints texps)))

;; A generalized (let-bound) lambda: the body's types must match the
;; return types inside the for-all.
(define-matcher (constrain:for-all-lambda ((? for-all-type? type))
                                          ('lambda ((? symbol? args) ...)
                                            (? texp? body)))
  (cons (pk 'constrain-for-all-lambda
            (constrain (procedure-type-return-types (for-all-type-type type))
                       (texp-types body)))
        (program-constraints body)))

(define-matcher (constrain:lambda ((? procedure-type? type))
                                  ('lambda ((? symbol? args) ...)
                                    (? texp? body)))
  (cons (pk 'constrain-lambda
            (constrain (procedure-type-return-types type)
                       (texp-types body)))
        (program-constraints body)))

;; The operator's type must unify with a procedure taking the operand
;; types and returning the call's types.
(define-matcher (constrain:call ((? type? types) ...)
                                ((? texp? operator) (? texp? operands) ...))
  (cons (pk 'constrain-call
            (constrain (texp-types operator)
                       (list (make-procedure-type (map single-type operands)
                                                  types))))
        (append (program-constraints operator)
                (append-map program-constraints operands))))

(define-matcher (constrain:values ((? type? types) ...)
                                  ('values (? texp? vals) ...))
  (append-map program-constraints vals))

;; The producer must be a thunk returning the consumer's parameter
;; types (TYPE is the consumer's procedure type; see
;; `annotate:call-with-values').
(define-matcher (constrain:call-with-values ((? procedure-type? type))
                                            ('call-with-values
                                             (? texp? producer)
                                             (? texp? consumer)))
  (define params (procedure-type-parameter-types type))
  (cons (pk 'constrain-call-with-values
            (constrain (texp-types producer)
                       (list (make-procedure-type '() params))))
        (append (program-constraints producer)
                (program-constraints consumer))))

(define %program-constraints
  (compose-matchers constrain:if
                    constrain:let
                    constrain:begin
                    constrain:for-all-lambda
                    constrain:lambda
                    constrain:values
                    constrain:call-with-values
                    constrain:call
                    constrain:other))

;; Return the list of constraints generated by the annotated TEXP.
(define (program-constraints texp)
  (%program-constraints (texp-types texp) (texp-exp texp)))

;;;
;;; Unification
;;;
;;; Substitutions ("dicts") are alists from type variables to types.
;;; Failure is signalled by aborting to `%unify-prompt-tag'.

(define %unify-prompt-tag (make-prompt-tag 'unify))

;; Call THUNK; if unification fails inside it, return the result of
;; calling FAILURE-HANDLER instead.
(define (call-with-unify-backtrack thunk failure-handler)
  (call-with-prompt %unify-prompt-tag
    thunk
    (lambda (_k) (failure-handler))))

(define (unify-fail)
  (pk 'unify-fail)
  (abort-to-prompt %unify-prompt-tag))

;; Unify type variable VAR with OTHER under DICT.
(define (maybe-substitute var other dict)
  (cond
   ;; Tautology: matching 2 vars that are the same.
   ((and (type-variable? other) (eq? var other))
    (pk 'tautology var other)
    dict)
   ;; Variable has been bound to some other value, recursively follow
   ;; it and unify.
   ((assq-ref dict var)
    => (lambda (type)
         (pk 'forward var other)
         (unify type other dict)))
   ;; Substitute variable for value.
   (else
    (pk 'substitute var other)
    (or (substitute var other dict)
        (unify-fail)))))

;; True for anything that is not a type variable, procedure type or list.
(define (constant? obj)
  (and (not (type-variable? obj))
       (not (procedure-type? obj))
       (not (list? obj))))

;; Fallback: succeed only if A and B are `eqv?' (also covers '() vs '()).
(define-matcher (unify:constants a b dict)
  (pk 'unify:constants a b)
  (if (eqv? a b)
      dict
      (unify-fail)))

;; Unify two non-empty lists element-wise.  Lists of different lengths
;; end up in `unify:constants' and fail.
(define-matcher (unify:lists (a rest-a ...) (b rest-b ...) dict)
  (pk 'unify:lists (cons a rest-a) (cons b rest-b))
  (let ((dict* (unify a b dict)))
    (unify rest-a rest-b dict*)))

(define-matcher (unify:variable-left (? type-variable? a) b dict)
  (pk 'unify:variable-left a b)
  (maybe-substitute a b dict))

(define-matcher (unify:variable-right a (? type-variable? b) dict)
  (pk 'unify:variable-right a b)
  (maybe-substitute b a dict))

;; Unify parameter types, then return types.
(define-matcher (unify:procedures (? procedure-type? a) (? procedure-type? b) dict)
  (pk 'unify:procedures a b)
  (let ((dict* (unify (procedure-type-parameter-types a)
                      (procedure-type-parameter-types b)
                      dict)))
    (pk 'dict* dict*)
    (unify (procedure-type-return-types a)
           (procedure-type-return-types b)
           dict*)))

;; Try each alternative of the intersection in order; the first one
;; that unifies with B wins.
;;
;; NOTE(review): backtracking is local to this call.  If a later
;; constraint fails, the other alternatives are not retried.
(define-matcher (unify:intersection (? intersection-type? a) (? type? b) dict)
  (let loop ((types (intersection-type-types a)))
    (match types
      (() (unify-fail))
      ((t . rest)
       (call-with-prompt %unify-prompt-tag
         (lambda () (unify t b dict))
         (lambda (_k) (loop rest)))))))

(define unify
  (compose-matchers unify:variable-left
                    unify:variable-right
                    unify:procedures
                    unify:intersection
                    unify:lists
                    unify:constants))

;; Unify all CONSTRAINTS; returns the resulting substitution alist, or
;; #f if they cannot be unified.
(define (unify-constraints constraints)
  (call-with-unify-backtrack
   (lambda ()
     (unify (map constraint-lhs constraints)
            (map constraint-rhs constraints)
            '()))
   (lambda () #f)))

;;;
;;; Type Resolution
;;;

;; Anything unrecognized is returned unchanged.
(define-matcher (resolve:primitive x dict)
  (pk 'resolve-primitive)
  x)

(define-matcher (resolve:primitive-type (? primitive-type? type) dict)
  (pk 'resolve-primitive-type type)
  type)

;; Follow the binding of VAR in DICT, if any.
(define-matcher (resolve:type-variable (? type-variable? var) dict)
  (pk 'resolve-var var)
  (let ((type (assq-ref dict var)))
    (if type
        (resolve-types type dict)
        var)))

(define-matcher (resolve:procedure-type (? procedure-type? type) dict)
  (pk 'resolve-proc type)
  (make-procedure-type (resolve-types (procedure-type-parameter-types type) dict)
                       (resolve-types (procedure-type-return-types type) dict)))

(define-matcher (resolve:list (exps ...) dict)
  (pk 'resolve-list exps)
  (map (lambda (texp) (resolve-types texp dict)) exps))

;; Replace bound type variables throughout an annotated program (or any
;; tree of lists and types) using DICT.
(define resolve-types
  (compose-matchers resolve:primitive-type
                    resolve:type-variable
                    resolve:procedure-type
                    resolve:list
                    resolve:primitive))

;; Infer the types of EXP.  Returns EXP as an annotated texp in which
;; solved type variables are replaced by their solutions, or #f if the
;; generated constraints cannot be unified.
(define (infer-types exp)
  (parameterize ((unique-counter 0))
    (pk 'begin-inference)
    (let* ((texp (pk 'texp (annotate-exp* exp)))
           (constraints (pk 'constraints (program-constraints texp)))
           (substitutions (pk 'substitutions (unify-constraints constraints))))
      (and substitutions
           (resolve-types texp substitutions)))))