summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--infer2.scm278
1 files changed, 124 insertions, 154 deletions
diff --git a/infer2.scm b/infer2.scm
index e2ff375..4097530 100644
--- a/infer2.scm
+++ b/infer2.scm
@@ -63,15 +63,15 @@
(make-type-variable (unique-identifier)))
(define-record-type <procedure-type>
- (make-procedure-type arg-types return-type)
+ (make-procedure-type parameter-types return-types)
procedure-type?
- (arg-types procedure-type-arg-types)
- (return-type procedure-type-return-type))
+ (parameter-types procedure-type-parameter-types)
+ (return-types procedure-type-return-types))
(define (display-procedure-type type port)
(format port "#<proc-type ~a → ~a>"
- (procedure-type-arg-types type)
- (procedure-type-return-type type)))
+ (procedure-type-parameter-types type)
+ (procedure-type-return-types type)))
(set-record-type-printer! <procedure-type> display-procedure-type)
(define (type? obj)
@@ -86,54 +86,58 @@
(or (assq-ref env var)
(error "unbound variable" var)))
-(define (make-texp type exp)
- `(t ,type ,exp))
+(define (make-texp types exp)
+ `(t ,types ,exp))
(define (texp? obj)
(match obj
(('t _ _) #t)
(_ #f)))
-(define (texp-type texp)
+(define (texp-types texp)
(match texp
- (('t type _) type)))
+ (('t types _) types)))
(define (texp-exp texp)
(match texp
(('t _ exp) exp)))
(define (display-typed-expression texp port)
- (format port "#<texp ~a ~a>" (texp-type texp) (texp-exp texp)))
+ (format port "#<texp ~a ~a>" (texp-types texp) (texp-exp texp)))
(set-record-type-printer! <typed-expression> display-typed-expression)
(define-matcher (annotate:bool (? boolean? b) env)
- (make-texp bool-type b))
+ (make-texp (list bool-type) b))
(define-matcher (annotate:int (and (? number?) (? exact-integer? n)) env)
- (make-texp int-type n))
+ (make-texp (list int-type) n))
(define-matcher (annotate:float (and (? number?) (? inexact? n)) env)
- (make-texp float-type n))
+ (make-texp (list float-type) n))
(define-matcher (annotate:var (? symbol? var) env)
- (make-texp (lookup var env) var))
+ (make-texp (list (lookup var env)) var))
(define-matcher (annotate:if ('if predicate consequent alternate) env)
- (make-texp (fresh-type-variable)
+ (make-texp (list (fresh-type-variable))
`(if ,(annotate-exp predicate env)
,(annotate-exp consequent env)
,(annotate-exp alternate env))))
(define-matcher (annotate:lambda ('lambda ((? symbol? args) ...) body) env)
- (define arg-types (map (lambda (_name) (fresh-type-variable)) args))
- (define env* (append (map cons args arg-types) env))
- (define return-type (fresh-type-variable))
- (make-texp (make-procedure-type arg-types return-type)
- `(lambda ,args ,(annotate-exp body env*))))
+ (define parameter-types (map (lambda (_name) (fresh-type-variable)) args))
+ (define env* (append (map cons args parameter-types) env))
+ (define body* (annotate-exp body env*))
+ (define return-types
+ (map (lambda (_type) (fresh-type-variable)) (texp-types body*)))
+ (make-texp (list (make-procedure-type parameter-types return-types))
+ `(lambda ,args ,body*)))
(define-matcher (annotate:call (operator args ...) env)
- (make-texp (fresh-type-variable)
- (cons (annotate-exp operator env)
+ (define operator* (annotate-exp operator env))
+ (make-texp (map (lambda (_type) (fresh-type-variable))
+ (texp-types operator*))
+ (cons operator*
(map (lambda (arg) (annotate-exp arg env)) args))))
(define annotate-exp
@@ -166,31 +170,31 @@
(constraint-rhs constraint)))
(set-record-type-printer! <constraint> display-constraint)
-(define-matcher (constrain:other (? type? _) _) '())
+(define-matcher (constrain:other ((? type? _) ...) _) '())
-(define-matcher (constrain:if (? type? type)
+(define-matcher (constrain:if ((? type? types) ...)
('if (? texp? predicate)
(? texp? consequent)
(? texp? alternate)))
- (append (list (constrain bool-type (texp-type predicate))
- (constrain type (texp-type consequent))
- (constrain type (texp-type alternate)))
+ (append (list (constrain (list bool-type) (texp-types predicate))
+ (constrain types (texp-types consequent))
+ (constrain types (texp-types alternate)))
(program-constraints predicate)
(program-constraints consequent)
(program-constraints alternate)))
-(define-matcher (constrain:lambda (? procedure-type? type)
+(define-matcher (constrain:lambda ((? procedure-type? type))
('lambda ((? symbol? args) ...)
(? texp? body)))
- (cons (constrain (procedure-type-return-type type)
- (texp-type body))
+ (cons (constrain (procedure-type-return-types type)
+ (texp-types body))
(program-constraints body)))
-(define-matcher (constrain:call (? type? type)
+(define-matcher (constrain:call ((? type? types) ...)
((? texp? operator) (? texp? operands) ...))
- (cons (constrain (texp-type operator)
- (make-procedure-type (map texp-type operands)
- type))
+ (cons (constrain (texp-types operator)
+ (list (make-procedure-type (map texp-type operands)
+ types)))
(append (program-constraints operator)
(append-map program-constraints operands))))
@@ -205,26 +209,34 @@
;;;
-;;; Unification
+;;; Occurs check
;;;
(define-matcher (occurs:default _ _)
#f)
+(define-matcher (occurs:list (a rest-a ...) (b rest-b ...))
+ (or (occurs-in? a b)
+ (occurs-in? rest-a rest-b)))
+
(define-matcher (occurs:variable (? type-variable? a) (? type-variable? b))
(eq? a b))
(define-matcher (occurs:procedure (? type-variable? v) (? procedure-type? p))
- (or (any (lambda (arg-type)
- (occurs-in? v arg-type))
- (procedure-type-arg-types p))
- (occurs-in? v (procedure-type-return-type p))))
+ (or (occurs-in? v (procedure-type-parameter-types p))
+ (occurs-in? v (procedure-type-return-types p))))
(define occurs-in?
- (compose-matchers occurs:variable
+ (compose-matchers occurs:list
+ occurs:variable
occurs:procedure
occurs:default))
+
+;;;
+;;; Unification
+;;;
+
(define (substitute-term term dict)
(match dict
(() term)
@@ -240,127 +252,86 @@
(if (eq? b var) term b))))
dict))
-(define (substitute var term dict)
- (let ((term* (substitute-term term dict)))
- (and (not (occurs-in? var term*))
- (alist-cons var term* (substitute-dict var term* dict)))))
-
-(define (maybe-substitute var-first terms)
- (lambda (dict succeed fail)
- (match var-first
- ((var . rest-vars)
- (match terms
- ((term . rest-terms)
- (cond
- ;; Tautology: matching 2 vars that are the same.
- ((and (type-variable? term) (eq? var term))
- (succeed dict fail rest-vars rest-terms))
- ;; Variable has been bound to some other value, recursively
- ;; follow it and unify.
- ((assq-ref dict var) =>
- (lambda (type)
- ((unify:dispatch (cons type rest-vars) terms)
- dict succeed fail)))
- ;; Substitute variable for value.
- (else
- (let ((dict* (substitute var term dict)))
- (if dict*
- (succeed dict* fail rest-vars rest-terms)
- (fail)))))))))))
+(define (substitute var other dict)
+ (let ((other* (substitute-term other dict)))
+ (and (not (occurs-in? var other*))
+ (alist-cons var other* (substitute-dict var other* dict)))))
+
+(define %unify-prompt-tag (make-prompt-tag 'unify))
+
+(define (unify-fail)
+ (pk 'unify-fail)
+ (abort-to-prompt %unify-prompt-tag))
+
+(define (maybe-substitute var other dict)
+ (cond
+ ;; Tautology: matching 2 vars that are the same.
+ ((and (type-variable? other) (eq? var other))
+ dict)
+ ;; Variable has been bound to some other value, recursively follow
+ ;; it and unify.
+ ((assq-ref dict var) =>
+ (lambda (type)
+ (unify type other dict)))
+ ;; Substitute variable for value.
+ (else
+ (or (substitute var other dict)
+ (unify-fail)))))
(define (constant? obj)
(and (not (type-variable? obj))
(not (procedure-type? obj))
(not (list? obj))))
-(define-matcher (unify:fail a b)
- (define (fail-unifier dict succeed fail)
- (fail))
- fail-unifier)
-
-(define-matcher (unify:constants ((? constant? a) . rest-a) ((? constant? b) . rest-b))
- (define (constant-unifier dict succeed fail)
- (if (eqv? a b)
- (begin
- (succeed dict fail rest-a rest-b))
- (fail)))
- constant-unifier)
-
-(define-matcher (unify:lists (a . rest-a) (b . rest-b))
- (define (list-unifier dict succeed fail)
- ((unify:dispatch a b)
- dict
- (lambda (dict* fail* _null-a _null-b)
- (succeed dict* fail* rest-a rest-b))
- fail))
- list-unifier)
-
-(define-matcher (unify:variable-left (and ((? type-variable? _) . _) a) b)
- (maybe-substitute a b))
-
-(define-matcher (unify:variable-right a (and ((? type-variable? _) . _) b))
- (maybe-substitute b a))
-
-(define-matcher (unify:procedures ((? procedure-type? a) . rest-a)
- ((? procedure-type? b) . rest-b))
- (define (procedure-unifier dict succeed fail)
- ((unify:dispatch (cons (procedure-type-return-type a)
- (procedure-type-arg-types a))
- (cons (procedure-type-return-type b)
- (procedure-type-arg-types b)))
- dict
- (lambda (dict* fail* _null-a _null-b)
- (succeed dict* fail* rest-a rest-b))
- fail))
- procedure-unifier)
-
-(define %unify:dispatch
- (compose-matchers unify:constants
- unify:variable-left
+(define-matcher (unify:constants a b dict)
+ (pk 'unify:constants a b)
+ (if (eqv? a b)
+ dict
+ (unify-fail)))
+
+(define-matcher (unify:lists (a rest-a ...) (b rest-b ...) dict)
+ (pk 'unify:lists (cons a rest-a) (cons b rest-b))
+ (let ((dict* (unify a b dict)))
+ (unify rest-a rest-b dict*)))
+
+(define-matcher (unify:variable-left (? type-variable? a) b dict)
+ (pk 'unify:variable-left a b)
+ (maybe-substitute a b dict))
+
+(define-matcher (unify:variable-right a (? type-variable? b) dict)
+ (pk 'unify:variable-right a b)
+ (maybe-substitute b a dict))
+
+(define-matcher (unify:procedures (? procedure-type? a) (? procedure-type? b) dict)
+ (pk 'unify:procedures a b)
+ (define dict* (unify (procedure-type-parameter-types a)
+ (procedure-type-parameter-types b)
+ dict))
+ (unify (procedure-type-return-types a)
+ (procedure-type-return-types b)
+ dict*))
+
+(define unify
+ (compose-matchers unify:variable-left
unify:variable-right
unify:procedures
unify:lists
- unify:fail))
-
-(define (unify:dispatch a b)
- (define (dispatcher dict succeed fail)
- (if (and (null? a) (null? b))
- (succeed dict fail a b)
- ((%unify:dispatch a b)
- dict
- (lambda (dict* fail* rest-a rest-b)
- ((unify:dispatch rest-a rest-b)
- dict* succeed fail*))
- fail)))
- dispatcher)
-
-(define (%unify a b dict succeed)
- ((unify:dispatch (list a) (list b))
- dict
- (lambda (dict fail rest-a rest-b)
- (or (and (null? rest-a) (null? rest-b) (succeed dict))
- (fail)))
- (lambda ()
- #f)))
-
-(define (unify a b)
- (%unify a b '() identity))
+ unify:constants))
(define (unify-constraints constraints)
- (unify (map constraint-lhs constraints)
- (map constraint-rhs constraints)))
+ (call-with-prompt %unify-prompt-tag
+ (lambda ()
+ (unify (map constraint-lhs constraints)
+ (map constraint-rhs constraints)
+ '()))
+ (lambda (_k) #f)))
;;;
;;; Type Resolution
;;;
-(define (primitive? x)
- (or (number? x)
- (boolean? x)
- (symbol? x)))
-
-(define-matcher (resolve:primitive (? primitive? x) dict)
+(define-matcher (resolve:primitive x dict)
x)
(define-matcher (resolve:primitive-type (? primitive-type? type) dict)
@@ -373,30 +344,29 @@
type)))
(define-matcher (resolve:procedure-type (? procedure-type? type) dict)
- (make-procedure-type (resolve-types (procedure-type-arg-types type) dict)
- (resolve-types (procedure-type-return-type type) dict)))
+ (make-procedure-type (resolve-types (procedure-type-parameter-types type) dict)
+ (resolve-types (procedure-type-return-types type) dict)))
(define-matcher (resolve:list (exps ...) dict)
(map (lambda (texp) (resolve-types texp dict)) exps))
(define resolve-types
- (compose-matchers resolve:primitive
- resolve:primitive-type
+ (compose-matchers resolve:primitive-type
resolve:type-variable
resolve:procedure-type
- resolve:list))
+ resolve:list
+ resolve:primitive))
(define (infer-types exp)
(let* ((texp (annotate-exp* exp))
- (constraints (program-constraints texp))
+ (constraints (pk 'constraints (program-constraints texp)))
(substitutions (unify-constraints constraints)))
- (unless substitutions
- (error "type mismatch" texp))
- (resolve-types texp substitutions)))
+ (and substitutions
+ (resolve-types texp substitutions))))
(infer-types #t)
(infer-types 6)
(infer-types 6.5)
(infer-types '(if #t 1 2))
(infer-types '((lambda (x) x) 1))
-(infer-types '((lambda (f) (if (f #t) (f 1) (f 2))) (lambda (x) x)))
+;;(infer-types '((lambda (f) (if (f #t) (f 1) (f 2))) (lambda (x) x)))