diff options
Diffstat (limited to 'infer2.scm')
-rw-r--r-- | infer2.scm | 278 |
1 files changed, 124 insertions, 154 deletions
@@ -63,15 +63,15 @@ (make-type-variable (unique-identifier))) (define-record-type <procedure-type> - (make-procedure-type arg-types return-type) + (make-procedure-type parameter-types return-types) procedure-type? - (arg-types procedure-type-arg-types) - (return-type procedure-type-return-type)) + (parameter-types procedure-type-parameter-types) + (return-types procedure-type-return-types)) (define (display-procedure-type type port) (format port "#<proc-type ~a → ~a>" - (procedure-type-arg-types type) - (procedure-type-return-type type))) + (procedure-type-parameter-types type) + (procedure-type-return-types type))) (set-record-type-printer! <procedure-type> display-procedure-type) (define (type? obj) @@ -86,54 +86,58 @@ (or (assq-ref env var) (error "unbound variable" var))) -(define (make-texp type exp) - `(t ,type ,exp)) +(define (make-texp types exp) + `(t ,types ,exp)) (define (texp? obj) (match obj (('t _ _) #t) (_ #f))) -(define (texp-type texp) +(define (texp-types texp) (match texp - (('t type _) type))) + (('t types _) types))) (define (texp-exp texp) (match texp (('t _ exp) exp))) (define (display-typed-expression texp port) - (format port "#<texp ~a ~a>" (texp-type texp) (texp-exp texp))) + (format port "#<texp ~a ~a>" (texp-types texp) (texp-exp texp))) (set-record-type-printer! <typed-expression> display-typed-expression) (define-matcher (annotate:bool (? boolean? b) env) - (make-texp bool-type b)) + (make-texp (list bool-type) b)) (define-matcher (annotate:int (and (? number?) (? exact-integer? n)) env) - (make-texp int-type n)) + (make-texp (list int-type) n)) (define-matcher (annotate:float (and (? number?) (? inexact? n)) env) - (make-texp float-type n)) + (make-texp (list float-type) n)) (define-matcher (annotate:var (? symbol? var) env) - (make-texp (lookup var env) var)) + (make-texp (list (lookup var env)) var)) (define-matcher (annotate:if ('if predicate consequent alternate) env) - (make-texp (fresh-type-variable) + (make-texp (list (fresh-type-variable)) `(if ,(annotate-exp predicate env) ,(annotate-exp consequent env) ,(annotate-exp alternate env)))) (define-matcher (annotate:lambda ('lambda ((? symbol? args) ...) body) env) - (define arg-types (map (lambda (_name) (fresh-type-variable)) args)) - (define env* (append (map cons args arg-types) env)) - (define return-type (fresh-type-variable)) - (make-texp (make-procedure-type arg-types return-type) - `(lambda ,args ,(annotate-exp body env*)))) + (define parameter-types (map (lambda (_name) (fresh-type-variable)) args)) + (define env* (append (map cons args parameter-types) env)) + (define body* (annotate-exp body env*)) + (define return-types + (map (lambda (_type) (fresh-type-variable)) (texp-types body*))) + (make-texp (list (make-procedure-type parameter-types return-types)) + `(lambda ,args ,body*))) (define-matcher (annotate:call (operator args ...) env) - (make-texp (fresh-type-variable) - (cons (annotate-exp operator env) + (define operator* (annotate-exp operator env)) + (make-texp (map (lambda (_type) (fresh-type-variable)) + (texp-types operator*)) + (cons operator* (map (lambda (arg) (annotate-exp arg env)) args)))) (define annotate-exp @@ -166,31 +170,31 @@ (constraint-rhs constraint))) (set-record-type-printer! <constraint> display-constraint) -(define-matcher (constrain:other (? type? _) _) '()) +(define-matcher (constrain:other ((? type? _) ...) _) '()) -(define-matcher (constrain:if (? type? type) +(define-matcher (constrain:if ((? type? types) ...) ('if (? texp? predicate) (? texp? consequent) (? texp? alternate))) - (append (list (constrain bool-type (texp-type predicate)) - (constrain type (texp-type consequent)) - (constrain type (texp-type alternate))) + (append (list (constrain (list bool-type) (texp-types predicate)) + (constrain types (texp-types consequent)) + (constrain types (texp-types alternate))) (program-constraints predicate) (program-constraints consequent) (program-constraints alternate))) -(define-matcher (constrain:lambda (? procedure-type? type) +(define-matcher (constrain:lambda ((? procedure-type? type)) ('lambda ((? symbol? args) ...) (? texp? body))) - (cons (constrain (procedure-type-return-type type) - (texp-type body)) + (cons (constrain (procedure-type-return-types type) + (texp-types body)) (program-constraints body))) -(define-matcher (constrain:call (? type? type) +(define-matcher (constrain:call ((? type? types) ...) ((? texp? operator) (? texp? operands) ...)) - (cons (constrain (texp-type operator) - (make-procedure-type (map texp-type operands) - type)) + (cons (constrain (texp-types operator) + (list (make-procedure-type (map texp-type operands) + types))) (append (program-constraints operator) (append-map program-constraints operands)))) @@ -205,26 +209,34 @@ ;;; -;;; Unification +;;; Occurs check ;;; (define-matcher (occurs:default _ _) #f) +(define-matcher (occurs:list (a rest-a ...) (b rest-b ...)) + (or (occurs-in? a b) + (occurs-in? rest-a rest-b))) + (define-matcher (occurs:variable (? type-variable? a) (? type-variable? b)) (eq? a b)) (define-matcher (occurs:procedure (? type-variable? v) (? procedure-type? p)) - (or (any (lambda (arg-type) - (occurs-in? v arg-type)) - (procedure-type-arg-types p)) - (occurs-in? v (procedure-type-return-type p)))) + (or (occurs-in? v (procedure-type-parameter-types p)) + (occurs-in? v (procedure-type-return-types p)))) (define occurs-in? - (compose-matchers occurs:variable + (compose-matchers occurs:list + occurs:variable occurs:procedure occurs:default)) + +;;; +;;; Unification +;;; + (define (substitute-term term dict) (match dict (() term) @@ -240,127 +252,86 @@ (if (eq? b var) term b)))) dict)) -(define (substitute var term dict) - (let ((term* (substitute-term term dict))) - (and (not (occurs-in? var term*)) - (alist-cons var term* (substitute-dict var term* dict))))) - -(define (maybe-substitute var-first terms) - (lambda (dict succeed fail) - (match var-first - ((var . rest-vars) - (match terms - ((term . rest-terms) - (cond - ;; Tautology: matching 2 vars that are the same. - ((and (type-variable? term) (eq? var term)) - (succeed dict fail rest-vars rest-terms)) - ;; Variable has been bound to some other value, recursively - ;; follow it and unify. - ((assq-ref dict var) => - (lambda (type) - ((unify:dispatch (cons type rest-vars) terms) - dict succeed fail))) - ;; Substitute variable for value. - (else - (let ((dict* (substitute var term dict))) - (if dict* - (succeed dict* fail rest-vars rest-terms) - (fail))))))))))) +(define (substitute var other dict) + (let ((other* (substitute-term other dict))) + (and (not (occurs-in? var other*)) + (alist-cons var other* (substitute-dict var other* dict))))) + +(define %unify-prompt-tag (make-prompt-tag 'unify)) + +(define (unify-fail) + (pk 'unify-fail) + (abort-to-prompt %unify-prompt-tag)) + +(define (maybe-substitute var other dict) + (cond + ;; Tautology: matching 2 vars that are the same. + ((and (type-variable? other) (eq? var other)) + dict) + ;; Variable has been bound to some other value, recursively follow + ;; it and unify. + ((assq-ref dict var) => + (lambda (type) + (unify type other dict))) + ;; Substitute variable for value. + (else + (or (substitute var other dict) + (unify-fail))))) (define (constant? obj) (and (not (type-variable? obj)) (not (procedure-type? obj)) (not (list? obj)))) -(define-matcher (unify:fail a b) - (define (fail-unifier dict succeed fail) - (fail)) - fail-unifier) - -(define-matcher (unify:constants ((? constant? a) . rest-a) ((? constant? b) . rest-b)) - (define (constant-unifier dict succeed fail) - (if (eqv? a b) - (begin - (succeed dict fail rest-a rest-b)) - (fail))) - constant-unifier) - -(define-matcher (unify:lists (a . rest-a) (b . rest-b)) - (define (list-unifier dict succeed fail) - ((unify:dispatch a b) - dict - (lambda (dict* fail* _null-a _null-b) - (succeed dict* fail* rest-a rest-b)) - fail)) - list-unifier) - -(define-matcher (unify:variable-left (and ((? type-variable? _) . _) a) b) - (maybe-substitute a b)) - -(define-matcher (unify:variable-right a (and ((? type-variable? _) . _) b)) - (maybe-substitute b a)) - -(define-matcher (unify:procedures ((? procedure-type? a) . rest-a) - ((? procedure-type? b) . rest-b)) - (define (procedure-unifier dict succeed fail) - ((unify:dispatch (cons (procedure-type-return-type a) - (procedure-type-arg-types a)) - (cons (procedure-type-return-type b) - (procedure-type-arg-types b))) - dict - (lambda (dict* fail* _null-a _null-b) - (succeed dict* fail* rest-a rest-b)) - fail)) - procedure-unifier) - -(define %unify:dispatch - (compose-matchers unify:constants - unify:variable-left +(define-matcher (unify:constants a b dict) + (pk 'unify:constants a b) + (if (eqv? a b) + dict + (unify-fail))) + +(define-matcher (unify:lists (a rest-a ...) (b rest-b ...) dict) + (pk 'unify:lists (cons a rest-a) (cons b rest-b)) + (let ((dict* (unify a b dict))) + (unify rest-a rest-b dict*))) + +(define-matcher (unify:variable-left (? type-variable? a) b dict) + (pk 'unify:variable-left a b) + (maybe-substitute a b dict)) + +(define-matcher (unify:variable-right a (? type-variable? b) dict) + (pk 'unify:variable-right a b) + (maybe-substitute b a dict)) + +(define-matcher (unify:procedures (? procedure-type? a) (? procedure-type? b) dict) + (pk 'unify:procedures a b) + (define dict* (unify (procedure-type-parameter-types a) + (procedure-type-parameter-types b) + dict)) + (unify (procedure-type-return-types a) + (procedure-type-return-types b) + dict*)) + +(define unify + (compose-matchers unify:variable-left unify:variable-right unify:procedures unify:lists - unify:fail)) - -(define (unify:dispatch a b) - (define (dispatcher dict succeed fail) - (if (and (null? a) (null? b)) - (succeed dict fail a b) - ((%unify:dispatch a b) - dict - (lambda (dict* fail* rest-a rest-b) - ((unify:dispatch rest-a rest-b) - dict* succeed fail*)) - fail))) - dispatcher) - -(define (%unify a b dict succeed) - ((unify:dispatch (list a) (list b)) - dict - (lambda (dict fail rest-a rest-b) - (or (and (null? rest-a) (null? rest-b) (succeed dict)) - (fail))) - (lambda () - #f))) - -(define (unify a b) - (%unify a b '() identity)) + unify:constants)) (define (unify-constraints constraints) - (unify (map constraint-lhs constraints) - (map constraint-rhs constraints))) + (call-with-prompt %unify-prompt-tag + (lambda () + (unify (map constraint-lhs constraints) + (map constraint-rhs constraints) + '())) + (lambda (_k) #f))) ;;; ;;; Type Resolution ;;; -(define (primitive? x) - (or (number? x) - (boolean? x) - (symbol? x))) - -(define-matcher (resolve:primitive (? primitive? x) dict) +(define-matcher (resolve:primitive x dict) x) (define-matcher (resolve:primitive-type (? primitive-type? type) dict) @@ -373,30 +344,29 @@ type))) (define-matcher (resolve:procedure-type (? procedure-type? type) dict) - (make-procedure-type (resolve-types (procedure-type-arg-types type) dict) - (resolve-types (procedure-type-return-type type) dict))) + (make-procedure-type (resolve-types (procedure-type-parameter-types type) dict) + (resolve-types (procedure-type-return-types type) dict))) (define-matcher (resolve:list (exps ...) dict) (map (lambda (texp) (resolve-types texp dict)) exps)) (define resolve-types - (compose-matchers resolve:primitive - resolve:primitive-type + (compose-matchers resolve:primitive-type resolve:type-variable resolve:procedure-type - resolve:list)) + resolve:list + resolve:primitive)) (define (infer-types exp) (let* ((texp (annotate-exp* exp)) - (constraints (program-constraints texp)) + (constraints (pk 'constraints (program-constraints texp))) (substitutions (unify-constraints constraints))) - (unless substitutions - (error "type mismatch" texp)) - (resolve-types texp substitutions))) + (and substitutions + (resolve-types texp substitutions)))) (infer-types #t) (infer-types 6) (infer-types 6.5) (infer-types '(if #t 1 2)) (infer-types '((lambda (x) x) 1)) -(infer-types '((lambda (f) (if (f #t) (f 1) (f 2))) (lambda (x) x))) +;;(infer-types '((lambda (f) (if (f #t) (f 1) (f 2))) (lambda (x) x))) |