summaryrefslogtreecommitdiff
path: root/infer2.scm
diff options
context:
space:
mode:
Diffstat (limited to 'infer2.scm')
-rw-r--r--infer2.scm402
1 files changed, 402 insertions, 0 deletions
diff --git a/infer2.scm b/infer2.scm
new file mode 100644
index 0000000..e2ff375
--- /dev/null
+++ b/infer2.scm
@@ -0,0 +1,402 @@
+(use-modules (ice-9 format)
+ (ice-9 match)
+ (srfi srfi-1)
+ (srfi srfi-9)
+ (srfi srfi-9 gnu))
+
+(define-syntax-rule (define-matcher (name rules ...) body ...)
+ (define name
+ (match-lambda*
+ ((rules ...) body ...)
+ (_ 'unmatched))))
+
+(define-syntax-rule (compose-matchers matcher ...)
+ (lambda args
+ (let loop ((matchers (list matcher ...)))
+ (match matchers
+ (() (error "unmatched" args))
+ ((m . rest)
+ (let ((result (apply m args)))
+ (if (eq? result 'unmatched)
+ (loop rest)
+ result)))))))
+
+
+;;;
+;;; Typed expression annotation
+;;;
+
+(define-record-type <primitive-type>
+ (make-primitive-type name)
+ primitive-type?
+ (name primitive-type-name))
+
+(define (display-primitive-type type port)
+ (format port "#<primitive ~a>" (primitive-type-name type)))
+(set-record-type-printer! <primitive-type> display-primitive-type)
+
+(define int-type (make-primitive-type 'int))
+(define float-type (make-primitive-type 'float))
+(define bool-type (make-primitive-type 'bool))
+
+(define-record-type <type-variable>
+ (make-type-variable name)
+ type-variable?
+ (name type-variable-name))
+
+(define (display-type-variable type port)
+ (format port "#<tvar ~a>" (type-variable-name type)))
+(set-record-type-printer! <type-variable> display-type-variable)
+
+(define unique-counter (make-parameter 0))
+
+(define (unique-number)
+ (let ((n (unique-counter)))
+ (unique-counter (+ n 1))
+ n))
+
+(define* (unique-identifier #:optional (prefix 'T))
+ (string->symbol
+ (format #f "~a~a" prefix (unique-number))))
+
+(define (fresh-type-variable)
+ (make-type-variable (unique-identifier)))
+
+(define-record-type <procedure-type>
+ (make-procedure-type arg-types return-type)
+ procedure-type?
+ (arg-types procedure-type-arg-types)
+ (return-type procedure-type-return-type))
+
+(define (display-procedure-type type port)
+ (format port "#<proc-type ~a → ~a>"
+ (procedure-type-arg-types type)
+ (procedure-type-return-type type)))
+(set-record-type-printer! <procedure-type> display-procedure-type)
+
+(define (type? obj)
+ (or (primitive-type? obj)
+ (type-variable? obj)
+ (procedure-type? obj)))
+
+(define (top-level-env)
+ '())
+
+(define (lookup var env)
+ (or (assq-ref env var)
+ (error "unbound variable" var)))
+
+(define (make-texp type exp)
+ `(t ,type ,exp))
+
+(define (texp? obj)
+ (match obj
+ (('t _ _) #t)
+ (_ #f)))
+
+(define (texp-type texp)
+ (match texp
+ (('t type _) type)))
+
+(define (texp-exp texp)
+ (match texp
+ (('t _ exp) exp)))
+
+(define (display-typed-expression texp port)
+ (format port "#<texp ~a ~a>" (texp-type texp) (texp-exp texp)))
+(set-record-type-printer! <typed-expression> display-typed-expression)
+
+(define-matcher (annotate:bool (? boolean? b) env)
+ (make-texp bool-type b))
+
+(define-matcher (annotate:int (and (? number?) (? exact-integer? n)) env)
+ (make-texp int-type n))
+
+(define-matcher (annotate:float (and (? number?) (? inexact? n)) env)
+ (make-texp float-type n))
+
+(define-matcher (annotate:var (? symbol? var) env)
+ (make-texp (lookup var env) var))
+
+(define-matcher (annotate:if ('if predicate consequent alternate) env)
+ (make-texp (fresh-type-variable)
+ `(if ,(annotate-exp predicate env)
+ ,(annotate-exp consequent env)
+ ,(annotate-exp alternate env))))
+
+(define-matcher (annotate:lambda ('lambda ((? symbol? args) ...) body) env)
+ (define arg-types (map (lambda (_name) (fresh-type-variable)) args))
+ (define env* (append (map cons args arg-types) env))
+ (define return-type (fresh-type-variable))
+ (make-texp (make-procedure-type arg-types return-type)
+ `(lambda ,args ,(annotate-exp body env*))))
+
+(define-matcher (annotate:call (operator args ...) env)
+ (make-texp (fresh-type-variable)
+ (cons (annotate-exp operator env)
+ (map (lambda (arg) (annotate-exp arg env)) args))))
+
+(define annotate-exp
+ (compose-matchers annotate:bool
+ annotate:int
+ annotate:float
+ annotate:var
+ annotate:if
+ annotate:lambda
+ annotate:call))
+
+(define (annotate-exp* exp)
+ (parameterize ((unique-counter 0))
+ (annotate-exp exp (top-level-env))))
+
+
+;;;
+;;; Constraints
+;;;
+
+(define-record-type <constraint>
+ (constrain lhs rhs)
+ constraint?
+ (lhs constraint-lhs)
+ (rhs constraint-rhs))
+
+(define (display-constraint constraint port)
+ (format port "#<constraint ~a = ~a>"
+ (constraint-lhs constraint)
+ (constraint-rhs constraint)))
+(set-record-type-printer! <constraint> display-constraint)
+
+(define-matcher (constrain:other (? type? _) _) '())
+
+(define-matcher (constrain:if (? type? type)
+ ('if (? texp? predicate)
+ (? texp? consequent)
+ (? texp? alternate)))
+ (append (list (constrain bool-type (texp-type predicate))
+ (constrain type (texp-type consequent))
+ (constrain type (texp-type alternate)))
+ (program-constraints predicate)
+ (program-constraints consequent)
+ (program-constraints alternate)))
+
+(define-matcher (constrain:lambda (? procedure-type? type)
+ ('lambda ((? symbol? args) ...)
+ (? texp? body)))
+ (cons (constrain (procedure-type-return-type type)
+ (texp-type body))
+ (program-constraints body)))
+
+(define-matcher (constrain:call (? type? type)
+ ((? texp? operator) (? texp? operands) ...))
+ (cons (constrain (texp-type operator)
+ (make-procedure-type (map texp-type operands)
+ type))
+ (append (program-constraints operator)
+ (append-map program-constraints operands))))
+
+(define %program-constraints
+ (compose-matchers constrain:if
+ constrain:lambda
+ constrain:call
+ constrain:other))
+
+(define (program-constraints texp)
+ (%program-constraints (texp-type texp) (texp-exp texp)))
+
+
+;;;
+;;; Unification
+;;;
+
+(define-matcher (occurs:default _ _)
+ #f)
+
+(define-matcher (occurs:variable (? type-variable? a) (? type-variable? b))
+ (eq? a b))
+
+(define-matcher (occurs:procedure (? type-variable? v) (? procedure-type? p))
+ (or (any (lambda (arg-type)
+ (occurs-in? v arg-type))
+ (procedure-type-arg-types p))
+ (occurs-in? v (procedure-type-return-type p))))
+
+(define occurs-in?
+ (compose-matchers occurs:variable
+ occurs:procedure
+ occurs:default))
+
+(define (substitute-term term dict)
+ (match dict
+ (() term)
+ (((from . to) . rest)
+ (if (eq? term from)
+ to
+ (substitute-term term rest)))))
+
+(define (substitute-dict var term dict)
+ (map (match-lambda
+ ((a . b)
+ (cons (if (eq? a var) term a)
+ (if (eq? b var) term b))))
+ dict))
+
+(define (substitute var term dict)
+ (let ((term* (substitute-term term dict)))
+ (and (not (occurs-in? var term*))
+ (alist-cons var term* (substitute-dict var term* dict)))))
+
+(define (maybe-substitute var-first terms)
+ (lambda (dict succeed fail)
+ (match var-first
+ ((var . rest-vars)
+ (match terms
+ ((term . rest-terms)
+ (cond
+ ;; Tautology: matching 2 vars that are the same.
+ ((and (type-variable? term) (eq? var term))
+ (succeed dict fail rest-vars rest-terms))
+ ;; Variable has been bound to some other value, recursively
+ ;; follow it and unify.
+ ((assq-ref dict var) =>
+ (lambda (type)
+ ((unify:dispatch (cons type rest-vars) terms)
+ dict succeed fail)))
+ ;; Substitute variable for value.
+ (else
+ (let ((dict* (substitute var term dict)))
+ (if dict*
+ (succeed dict* fail rest-vars rest-terms)
+ (fail)))))))))))
+
+(define (constant? obj)
+ (and (not (type-variable? obj))
+ (not (procedure-type? obj))
+ (not (list? obj))))
+
+(define-matcher (unify:fail a b)
+ (define (fail-unifier dict succeed fail)
+ (fail))
+ fail-unifier)
+
+(define-matcher (unify:constants ((? constant? a) . rest-a) ((? constant? b) . rest-b))
+ (define (constant-unifier dict succeed fail)
+ (if (eqv? a b)
+ (begin
+ (succeed dict fail rest-a rest-b))
+ (fail)))
+ constant-unifier)
+
+(define-matcher (unify:lists (a . rest-a) (b . rest-b))
+ (define (list-unifier dict succeed fail)
+ ((unify:dispatch a b)
+ dict
+ (lambda (dict* fail* _null-a _null-b)
+ (succeed dict* fail* rest-a rest-b))
+ fail))
+ list-unifier)
+
+(define-matcher (unify:variable-left (and ((? type-variable? _) . _) a) b)
+ (maybe-substitute a b))
+
+(define-matcher (unify:variable-right a (and ((? type-variable? _) . _) b))
+ (maybe-substitute b a))
+
+(define-matcher (unify:procedures ((? procedure-type? a) . rest-a)
+ ((? procedure-type? b) . rest-b))
+ (define (procedure-unifier dict succeed fail)
+ ((unify:dispatch (cons (procedure-type-return-type a)
+ (procedure-type-arg-types a))
+ (cons (procedure-type-return-type b)
+ (procedure-type-arg-types b)))
+ dict
+ (lambda (dict* fail* _null-a _null-b)
+ (succeed dict* fail* rest-a rest-b))
+ fail))
+ procedure-unifier)
+
+(define %unify:dispatch
+ (compose-matchers unify:constants
+ unify:variable-left
+ unify:variable-right
+ unify:procedures
+ unify:lists
+ unify:fail))
+
+(define (unify:dispatch a b)
+ (define (dispatcher dict succeed fail)
+ (if (and (null? a) (null? b))
+ (succeed dict fail a b)
+ ((%unify:dispatch a b)
+ dict
+ (lambda (dict* fail* rest-a rest-b)
+ ((unify:dispatch rest-a rest-b)
+ dict* succeed fail*))
+ fail)))
+ dispatcher)
+
+(define (%unify a b dict succeed)
+ ((unify:dispatch (list a) (list b))
+ dict
+ (lambda (dict fail rest-a rest-b)
+ (or (and (null? rest-a) (null? rest-b) (succeed dict))
+ (fail)))
+ (lambda ()
+ #f)))
+
+(define (unify a b)
+ (%unify a b '() identity))
+
+(define (unify-constraints constraints)
+ (unify (map constraint-lhs constraints)
+ (map constraint-rhs constraints)))
+
+
+;;;
+;;; Type Resolution
+;;;
+
+(define (primitive? x)
+ (or (number? x)
+ (boolean? x)
+ (symbol? x)))
+
+(define-matcher (resolve:primitive (? primitive? x) dict)
+ x)
+
+(define-matcher (resolve:primitive-type (? primitive-type? type) dict)
+ type)
+
+(define-matcher (resolve:type-variable (? type-variable? var) dict)
+ (let ((type (assq-ref dict var)))
+ (if (or (not type) (type-variable? type))
+ (error "cannot determine type" var)
+ type)))
+
+(define-matcher (resolve:procedure-type (? procedure-type? type) dict)
+ (make-procedure-type (resolve-types (procedure-type-arg-types type) dict)
+ (resolve-types (procedure-type-return-type type) dict)))
+
+(define-matcher (resolve:list (exps ...) dict)
+ (map (lambda (texp) (resolve-types texp dict)) exps))
+
+(define resolve-types
+ (compose-matchers resolve:primitive
+ resolve:primitive-type
+ resolve:type-variable
+ resolve:procedure-type
+ resolve:list))
+
+(define (infer-types exp)
+ (let* ((texp (annotate-exp* exp))
+ (constraints (program-constraints texp))
+ (substitutions (unify-constraints constraints)))
+ (unless substitutions
+ (error "type mismatch" texp))
+ (resolve-types texp substitutions)))
+
+(infer-types #t)
+(infer-types 6)
+(infer-types 6.5)
+(infer-types '(if #t 1 2))
+(infer-types '((lambda (x) x) 1))
+(infer-types '((lambda (f) (if (f #t) (f 1) (f 2))) (lambda (x) x)))