diff options
-rw-r--r-- | infer2.scm | 402 |
1 files changed, 402 insertions, 0 deletions
diff --git a/infer2.scm b/infer2.scm new file mode 100644 index 0000000..e2ff375 --- /dev/null +++ b/infer2.scm @@ -0,0 +1,402 @@ +(use-modules (ice-9 format) + (ice-9 match) + (srfi srfi-1) + (srfi srfi-9) + (srfi srfi-9 gnu)) + +(define-syntax-rule (define-matcher (name rules ...) body ...) + (define name + (match-lambda* + ((rules ...) body ...) + (_ 'unmatched)))) + +(define-syntax-rule (compose-matchers matcher ...) + (lambda args + (let loop ((matchers (list matcher ...))) + (match matchers + (() (error "unmatched" args)) + ((m . rest) + (let ((result (apply m args))) + (if (eq? result 'unmatched) + (loop rest) + result))))))) + + +;;; +;;; Typed expression annotation +;;; + +(define-record-type <primitive-type> + (make-primitive-type name) + primitive-type? + (name primitive-type-name)) + +(define (display-primitive-type type port) + (format port "#<primitive ~a>" (primitive-type-name type))) +(set-record-type-printer! <primitive-type> display-primitive-type) + +(define int-type (make-primitive-type 'int)) +(define float-type (make-primitive-type 'float)) +(define bool-type (make-primitive-type 'bool)) + +(define-record-type <type-variable> + (make-type-variable name) + type-variable? + (name type-variable-name)) + +(define (display-type-variable type port) + (format port "#<tvar ~a>" (type-variable-name type))) +(set-record-type-printer! <type-variable> display-type-variable) + +(define unique-counter (make-parameter 0)) + +(define (unique-number) + (let ((n (unique-counter))) + (unique-counter (+ n 1)) + n)) + +(define* (unique-identifier #:optional (prefix 'T)) + (string->symbol + (format #f "~a~a" prefix (unique-number)))) + +(define (fresh-type-variable) + (make-type-variable (unique-identifier))) + +(define-record-type <procedure-type> + (make-procedure-type arg-types return-type) + procedure-type? + (arg-types procedure-type-arg-types) + (return-type procedure-type-return-type)) + +(define (display-procedure-type type port) + (format port "#<proc-type ~a → ~a>" + (procedure-type-arg-types type) + (procedure-type-return-type type))) +(set-record-type-printer! <procedure-type> display-procedure-type) + +(define (type? obj) + (or (primitive-type? obj) + (type-variable? obj) + (procedure-type? obj))) + +(define (top-level-env) + '()) + +(define (lookup var env) + (or (assq-ref env var) + (error "unbound variable" var))) + +(define (make-texp type exp) + `(t ,type ,exp)) + +(define (texp? obj) + (match obj + (('t _ _) #t) + (_ #f))) + +(define (texp-type texp) + (match texp + (('t type _) type))) + +(define (texp-exp texp) + (match texp + (('t _ exp) exp))) + +(define (display-typed-expression texp port) + (format port "#<texp ~a ~a>" (texp-type texp) (texp-exp texp))) +(set-record-type-printer! <typed-expression> display-typed-expression) + +(define-matcher (annotate:bool (? boolean? b) env) + (make-texp bool-type b)) + +(define-matcher (annotate:int (and (? number?) (? exact-integer? n)) env) + (make-texp int-type n)) + +(define-matcher (annotate:float (and (? number?) (? inexact? n)) env) + (make-texp float-type n)) + +(define-matcher (annotate:var (? symbol? var) env) + (make-texp (lookup var env) var)) + +(define-matcher (annotate:if ('if predicate consequent alternate) env) + (make-texp (fresh-type-variable) + `(if ,(annotate-exp predicate env) + ,(annotate-exp consequent env) + ,(annotate-exp alternate env)))) + +(define-matcher (annotate:lambda ('lambda ((? symbol? args) ...) body) env) + (define arg-types (map (lambda (_name) (fresh-type-variable)) args)) + (define env* (append (map cons args arg-types) env)) + (define return-type (fresh-type-variable)) + (make-texp (make-procedure-type arg-types return-type) + `(lambda ,args ,(annotate-exp body env*)))) + +(define-matcher (annotate:call (operator args ...) env) + (make-texp (fresh-type-variable) + (cons (annotate-exp operator env) + (map (lambda (arg) (annotate-exp arg env)) args)))) + +(define annotate-exp + (compose-matchers annotate:bool + annotate:int + annotate:float + annotate:var + annotate:if + annotate:lambda + annotate:call)) + +(define (annotate-exp* exp) + (parameterize ((unique-counter 0)) + (annotate-exp exp (top-level-env)))) + + +;;; +;;; Constraints +;;; + +(define-record-type <constraint> + (constrain lhs rhs) + constraint? + (lhs constraint-lhs) + (rhs constraint-rhs)) + +(define (display-constraint constraint port) + (format port "#<constraint ~a = ~a>" + (constraint-lhs constraint) + (constraint-rhs constraint))) +(set-record-type-printer! <constraint> display-constraint) + +(define-matcher (constrain:other (? type? _) _) '()) + +(define-matcher (constrain:if (? type? type) + ('if (? texp? predicate) + (? texp? consequent) + (? texp? alternate))) + (append (list (constrain bool-type (texp-type predicate)) + (constrain type (texp-type consequent)) + (constrain type (texp-type alternate))) + (program-constraints predicate) + (program-constraints consequent) + (program-constraints alternate))) + +(define-matcher (constrain:lambda (? procedure-type? type) + ('lambda ((? symbol? args) ...) + (? texp? body))) + (cons (constrain (procedure-type-return-type type) + (texp-type body)) + (program-constraints body))) + +(define-matcher (constrain:call (? type? type) + ((? texp? operator) (? texp? operands) ...)) + (cons (constrain (texp-type operator) + (make-procedure-type (map texp-type operands) + type)) + (append (program-constraints operator) + (append-map program-constraints operands)))) + +(define %program-constraints + (compose-matchers constrain:if + constrain:lambda + constrain:call + constrain:other)) + +(define (program-constraints texp) + (%program-constraints (texp-type texp) (texp-exp texp))) + + +;;; +;;; Unification +;;; + +(define-matcher (occurs:default _ _) + #f) + +(define-matcher (occurs:variable (? type-variable? a) (? type-variable? b)) + (eq? a b)) + +(define-matcher (occurs:procedure (? type-variable? v) (? procedure-type? p)) + (or (any (lambda (arg-type) + (occurs-in? v arg-type)) + (procedure-type-arg-types p)) + (occurs-in? v (procedure-type-return-type p)))) + +(define occurs-in? + (compose-matchers occurs:variable + occurs:procedure + occurs:default)) + +(define (substitute-term term dict) + (match dict + (() term) + (((from . to) . rest) + (if (eq? term from) + to + (substitute-term term rest))))) + +(define (substitute-dict var term dict) + (map (match-lambda + ((a . b) + (cons (if (eq? a var) term a) + (if (eq? b var) term b)))) + dict)) + +(define (substitute var term dict) + (let ((term* (substitute-term term dict))) + (and (not (occurs-in? var term*)) + (alist-cons var term* (substitute-dict var term* dict))))) + +(define (maybe-substitute var-first terms) + (lambda (dict succeed fail) + (match var-first + ((var . rest-vars) + (match terms + ((term . rest-terms) + (cond + ;; Tautology: matching 2 vars that are the same. + ((and (type-variable? term) (eq? var term)) + (succeed dict fail rest-vars rest-terms)) + ;; Variable has been bound to some other value, recursively + ;; follow it and unify. + ((assq-ref dict var) => + (lambda (type) + ((unify:dispatch (cons type rest-vars) terms) + dict succeed fail))) + ;; Substitute variable for value. + (else + (let ((dict* (substitute var term dict))) + (if dict* + (succeed dict* fail rest-vars rest-terms) + (fail))))))))))) + +(define (constant? obj) + (and (not (type-variable? obj)) + (not (procedure-type? obj)) + (not (list? obj)))) + +(define-matcher (unify:fail a b) + (define (fail-unifier dict succeed fail) + (fail)) + fail-unifier) + +(define-matcher (unify:constants ((? constant? a) . rest-a) ((? constant? b) . rest-b)) + (define (constant-unifier dict succeed fail) + (if (eqv? a b) + (begin + (succeed dict fail rest-a rest-b)) + (fail))) + constant-unifier) + +(define-matcher (unify:lists (a . rest-a) (b . rest-b)) + (define (list-unifier dict succeed fail) + ((unify:dispatch a b) + dict + (lambda (dict* fail* _null-a _null-b) + (succeed dict* fail* rest-a rest-b)) + fail)) + list-unifier) + +(define-matcher (unify:variable-left (and ((? type-variable? _) . _) a) b) + (maybe-substitute a b)) + +(define-matcher (unify:variable-right a (and ((? type-variable? _) . _) b)) + (maybe-substitute b a)) + +(define-matcher (unify:procedures ((? procedure-type? a) . rest-a) + ((? procedure-type? b) . rest-b)) + (define (procedure-unifier dict succeed fail) + ((unify:dispatch (cons (procedure-type-return-type a) + (procedure-type-arg-types a)) + (cons (procedure-type-return-type b) + (procedure-type-arg-types b))) + dict + (lambda (dict* fail* _null-a _null-b) + (succeed dict* fail* rest-a rest-b)) + fail)) + procedure-unifier) + +(define %unify:dispatch + (compose-matchers unify:constants + unify:variable-left + unify:variable-right + unify:procedures + unify:lists + unify:fail)) + +(define (unify:dispatch a b) + (define (dispatcher dict succeed fail) + (if (and (null? a) (null? b)) + (succeed dict fail a b) + ((%unify:dispatch a b) + dict + (lambda (dict* fail* rest-a rest-b) + ((unify:dispatch rest-a rest-b) + dict* succeed fail*)) + fail))) + dispatcher) + +(define (%unify a b dict succeed) + ((unify:dispatch (list a) (list b)) + dict + (lambda (dict fail rest-a rest-b) + (or (and (null? rest-a) (null? rest-b) (succeed dict)) + (fail))) + (lambda () + #f))) + +(define (unify a b) + (%unify a b '() identity)) + +(define (unify-constraints constraints) + (unify (map constraint-lhs constraints) + (map constraint-rhs constraints))) + + +;;; +;;; Type Resolution +;;; + +(define (primitive? x) + (or (number? x) + (boolean? x) + (symbol? x))) + +(define-matcher (resolve:primitive (? primitive? x) dict) + x) + +(define-matcher (resolve:primitive-type (? primitive-type? type) dict) + type) + +(define-matcher (resolve:type-variable (? type-variable? var) dict) + (let ((type (assq-ref dict var))) + (if (or (not type) (type-variable? type)) + (error "cannot determine type" var) + type))) + +(define-matcher (resolve:procedure-type (? procedure-type? type) dict) + (make-procedure-type (resolve-types (procedure-type-arg-types type) dict) + (resolve-types (procedure-type-return-type type) dict))) + +(define-matcher (resolve:list (exps ...) dict) + (map (lambda (texp) (resolve-types texp dict)) exps)) + +(define resolve-types + (compose-matchers resolve:primitive + resolve:primitive-type + resolve:type-variable + resolve:procedure-type + resolve:list)) + +(define (infer-types exp) + (let* ((texp (annotate-exp* exp)) + (constraints (program-constraints texp)) + (substitutions (unify-constraints constraints))) + (unless substitutions + (error "type mismatch" texp)) + (resolve-types texp substitutions))) + +(infer-types #t) +(infer-types 6) +(infer-types 6.5) +(infer-types '(if #t 1 2)) +(infer-types '((lambda (x) x) 1)) +(infer-types '((lambda (f) (if (f #t) (f 1) (f 2))) (lambda (x) x))) |