summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Thompson <dthompson2@worcester.edu>2023-01-07 20:37:56 -0500
committerDavid Thompson <dthompson2@worcester.edu>2023-06-08 08:14:41 -0400
commit1fca41ed98b74bf4224ada808489b4b770d3ddf6 (patch)
tree8f081072be2cf04c694afa75177fc33bfabd1a7c
parent2c006b1517eb929c797efb350a15581cf7336f02 (diff)
Type quantifiers!!!!
-rw-r--r--infer2.scm283
1 files changed, 188 insertions, 95 deletions
diff --git a/infer2.scm b/infer2.scm
index b7211c7..03a8d32 100644
--- a/infer2.scm
+++ b/infer2.scm
@@ -4,6 +4,17 @@
(srfi srfi-9)
(srfi srfi-9 gnu))
+(define (print-list lst)
+ (for-each (lambda (x) (format #t "~a\n" x)) lst))
+
+(define (difference a b)
+ (match a
+ (() b)
+ ((x . rest)
+ (if (memq x b)
+ (difference rest (delq x b))
+ (cons x (difference rest b))))))
+
(define-syntax-rule (define-matcher (name rules ...) body ...)
(define name
(match-lambda*
@@ -23,7 +34,7 @@
;;;
-;;; Typed expression annotation
+;;; Types and environments
;;;
(define-record-type <primitive-type>
@@ -74,6 +85,12 @@
(procedure-type-return-types type)))
(set-record-type-printer! <procedure-type> display-procedure-type)
+(define-record-type <for-all-type>
+ (make-for-all-type variables type)
+ for-all-type?
+ (variables for-all-type-variables)
+ (type for-all-type-type))
+
(define (type? obj)
(or (primitive-type? obj)
(type-variable? obj)
@@ -83,8 +100,117 @@
'())
(define (lookup var env)
- (or (assq-ref env var)
- (error "unbound variable" var)))
+ (let ((type (assq-ref env var)))
+ (cond
+ ((type-variable? type) type)
+ ((for-all-type? type) (instantiate type))
+ (else (error "unbound variable" var)))))
+
+(define (occurs-in? a b)
+ (cond
+ ((and (type-variable? a) (type-variable? b))
+ (eq? a b))
+ ((and (type-variable? a) (procedure-type? b))
+ (or (occurs-in? a (procedure-type-parameter-types b))
+ (occurs-in? a (procedure-type-return-types b))))
+ ((and (type? a) (list? b))
+ (any (lambda (b*) (occurs-in? a b*)) b))
+ (else #f)))
+
+(define (apply-substitution-to-type type from to)
+ (cond
+ ((primitive-type? type) type)
+ ((type-variable? type)
+ (if (eq? type from) to type))
+ ((procedure-type? type)
+ (make-procedure-type
+ (map (lambda (param-type)
+ (apply-substitution-to-type param-type from to))
+ (procedure-type-parameter-types type))
+ (map (lambda (return-type)
+ (apply-substitution-to-type return-type from to))
+ (procedure-type-return-types type))))))
+
+(define (apply-substitutions-to-type type subs)
+ (let loop ((type type)
+ (subs subs))
+ (match subs
+ (() type)
+ (((from . to) . rest)
+ (loop (apply-substitution-to-type type from to)
+ rest)))))
+
+(define (apply-substitution-to-env from to env)
+ (map (match-lambda
+ ((a . b)
+ (cons (if (eq? a from) to a)
+ (if (eq? b from) to b))))
+ env))
+
+(define (substitute from to env)
+ (let* ((to* (apply-substitutions-to-type to env)))
+ (and (not (occurs-in? from to*))
+ (let ((env* (apply-substitution-to-env from to* env)))
+ (pk 'add-to-env from to*)
+ (alist-cons from to* env*)))))
+
+(define (free-variables-in-type type)
+ (cond
+ ((primitive-type? type) '())
+ ((type-variable? type) type)
+ ((procedure-type? type)
+ (delete-duplicates
+ (append (map free-variables-in-type
+ (procedure-type-parameter-types type))
+ (map free-variables-in-type
+ (procedure-type-return-types type)))))))
+
+(define (free-variables-in-for-all for-all)
+ (difference (for-all-type-variables for-all)
+ (free-variables-in-type (for-all-type-type for-all))))
+
+(define (free-variables-in-env env)
+ (delete-duplicates
+ (let loop ((env env))
+ (match env
+ (() '())
+ (((_ . type) . rest)
+ (cond
+ ((type-variable? type)
+ (cons (free-variables-in-type type)
+ (loop rest)))
+ ((for-all-type? type)
+ (cons (free-variables-in-for-all type)
+ (loop rest)))
+ (else
+ (loop rest))))))))
+
+(define (instantiate for-all)
+ (define subs
+ (map (lambda (var)
+ (cons var (fresh-type-variable)))
+ (for-all-type-variables for-all)))
+ (pk 'instantiate for-all)
+ (apply-substitutions-to-type (for-all-type-type for-all) subs))
+
+(define (generalize type env)
+ (define quantifiers
+ (difference (free-variables-in-type type)
+ (free-variables-in-env env)))
+ (pk 'quantifiers quantifiers)
+ (if (null? quantifiers)
+ type
+ (make-for-all-type quantifiers type)))
+
+(define (constraints->env constraints)
+ (map (lambda (c)
+ (cons (constraint-lhs c) (constraint-rhs c)))
+ constraints))
+
+
+;;;
+;;; Typed expression annotation
+;;;
(define (make-texp types exp)
`(t ,types ,exp))
@@ -102,9 +228,10 @@
(match texp
(('t _ exp) exp)))
-(define (display-typed-expression texp port)
- (format port "#<texp ~a ~a>" (texp-types texp) (texp-exp texp)))
-(set-record-type-printer! <typed-expression> display-typed-expression)
+(define (single-type texp)
+ (match (texp-types texp)
+ ((type) type)
+ (_ (error "expected only 1 type" texp))))
(define-matcher (annotate:bool (? boolean? b) env)
(make-texp (list bool-type) b))
@@ -126,29 +253,25 @@
(define-matcher (annotate:values ('values vals ...) env)
(define vals* (map (lambda (val) (annotate-exp val env)) vals))
- (define val-types
- (map (lambda (val)
- (match (texp-types val)
- ((type) type)
- (types (error "expressions in values must return 1 value" val))))
- vals*))
+ (define val-types (map single-type vals*))
(make-texp val-types `(values ,@vals*)))
(define-matcher (annotate:let ('let (((? symbol? vars) vals) ...) body) env)
(define vals* (map (lambda (val) (annotate-exp val env)) vals))
(define val-types (map (lambda (val)
- (match (texp-types val)
- ((type) type)
- (_ (error "let bindings must return 1 value" val))))
+ (generalize (single-type val) env))
vals*))
+ (define vals** (map (lambda (val type)
+ (make-texp (list type) (texp-exp val)))
+ vals* val-types))
(define env* (append (map cons vars val-types) env))
(define body* (annotate-exp body env*))
- (make-texp (map (lambda (_type) (fresh-type-variable)) (texp-types body*))
- `(let ,(map list vars vals*) ,body*)))
+ (make-texp (texp-types body*)
+ `(let ,(map list vars vals**) ,body*)))
(define-matcher (annotate:begin ('begin exps ... last) env)
(define last* (annotate-exp last env))
- (make-texp (map (lambda (_type) (fresh-type-variable)) (texp-types last*))
+ (make-texp (texp-types last*)
`(begin ,@(map (lambda (exp) (annotate-exp exp env)) exps)
,last*)))
@@ -156,17 +279,20 @@
(define parameter-types (map (lambda (_name) (fresh-type-variable)) args))
(define env* (append (map cons args parameter-types) env))
(define body* (annotate-exp body env*))
- (define return-types
- (map (lambda (_type) (fresh-type-variable)) (texp-types body*)))
+ (define return-types (texp-types body*))
(make-texp (list (make-procedure-type parameter-types return-types))
`(lambda ,args ,body*)))
(define-matcher (annotate:call (operator args ...) env)
(define operator* (annotate-exp operator env))
+ (define args* (map (lambda (arg)
+ (annotate-exp arg env))
+ args))
(make-texp (map (lambda (_type) (fresh-type-variable))
(texp-types operator*))
(cons operator*
- (map (lambda (arg) (annotate-exp arg env)) args))))
+ args*)))
+
(define annotate-exp
(compose-matchers annotate:bool
@@ -181,8 +307,7 @@
annotate:call))
(define (annotate-exp* exp)
- (parameterize ((unique-counter 0))
- (annotate-exp exp (top-level-env))))
+ (annotate-exp exp (top-level-env)))
;;;
@@ -221,28 +346,36 @@
(define-matcher (constrain:let ((? type? types) ...)
('let (((? symbol? vars) (? texp? vals)) ...)
(? texp? body)))
- (cons (constrain types (texp-types body))
- (append (program-constraints body)
- (append-map program-constraints vals))))
+ (append (program-constraints body)
+ (append-map program-constraints vals)))
(define-matcher (constrain:begin ((? type? types) ...)
('begin (? texp? texps) ... (? texp? last)))
- (cons (constrain types (texp-types last))
- (append (program-constraints last)
- (append-map program-constraints texps))))
+ (append (program-constraints last)
+ (append-map program-constraints texps)))
+
+(define-matcher (constrain:for-all-lambda ((? for-all-type? type))
+ ('lambda ((? symbol? args) ...)
+ (? texp? body)))
+ (cons (pk 'constrain-for-all-lambda
+ (constrain (procedure-type-return-types
+ (for-all-type-type type))
+ (texp-types body)))
+ (program-constraints body)))
(define-matcher (constrain:lambda ((? procedure-type? type))
('lambda ((? symbol? args) ...)
(? texp? body)))
- (cons (constrain (procedure-type-return-types type)
- (texp-types body))
+ (cons (pk 'constrain-lambda
+ (constrain (procedure-type-return-types type)
+ (texp-types body)))
(program-constraints body)))
(define-matcher (constrain:call ((? type? types) ...)
((? texp? operator) (? texp? operands) ...))
- (cons (constrain (texp-types operator)
- (list (make-procedure-type (map texp-type operands)
- types)))
+ (cons (pk 'constrain-call (constrain (texp-types operator)
+ (list (make-procedure-type (map single-type operands)
+ types))))
(append (program-constraints operator)
(append-map program-constraints operands))))
@@ -251,63 +384,19 @@
constrain:values
constrain:let
constrain:begin
+ constrain:for-all-lambda
constrain:lambda
constrain:call
constrain:other))
(define (program-constraints texp)
- (%program-constraints (texp-type texp) (texp-exp texp)))
-
-
-;;;
-;;; Occurs check
-;;;
-
-(define-matcher (occurs:default _ _)
- #f)
-
-(define-matcher (occurs:list (a rest-a ...) (b rest-b ...))
- (or (occurs-in? a b)
- (occurs-in? rest-a rest-b)))
-
-(define-matcher (occurs:variable (? type-variable? a) (? type-variable? b))
- (eq? a b))
-
-(define-matcher (occurs:procedure (? type-variable? v) (? procedure-type? p))
- (or (occurs-in? v (procedure-type-parameter-types p))
- (occurs-in? v (procedure-type-return-types p))))
-
-(define occurs-in?
- (compose-matchers occurs:list
- occurs:variable
- occurs:procedure
- occurs:default))
+ (%program-constraints (texp-types texp) (texp-exp texp)))
;;;
;;; Unification
;;;
-(define (substitute-term term dict)
- (match dict
- (() term)
- (((from . to) . rest)
- (if (eq? term from)
- to
- (substitute-term term rest)))))
-
-(define (substitute-dict var term dict)
- (map (match-lambda
- ((a . b)
- (cons (if (eq? a var) term a)
- (if (eq? b var) term b))))
- dict))
-
-(define (substitute var other dict)
- (let ((other* (substitute-term other dict)))
- (and (not (occurs-in? var other*))
- (alist-cons var other* (substitute-dict var other* dict)))))
-
(define %unify-prompt-tag (make-prompt-tag 'unify))
(define (unify-fail)
@@ -318,14 +407,17 @@
(cond
;; Tautology: matching 2 vars that are the same.
((and (type-variable? other) (eq? var other))
+ (pk 'tautology var other)
dict)
;; Variable has been bound to some other value, recursively follow
;; it and unify.
((assq-ref dict var) =>
(lambda (type)
+ (pk 'forward var other)
(unify type other dict)))
;; Substitute variable for value.
(else
+ (pk 'substitute var other)
(or (substitute var other dict)
(unify-fail)))))
@@ -358,6 +450,7 @@
(define dict* (unify (procedure-type-parameter-types a)
(procedure-type-parameter-types b)
dict))
+ (pk 'dict* dict*)
(unify (procedure-type-return-types a)
(procedure-type-return-types b)
dict*))
@@ -383,22 +476,27 @@
;;;
(define-matcher (resolve:primitive x dict)
+ (pk 'resolve-primitive)
x)
(define-matcher (resolve:primitive-type (? primitive-type? type) dict)
+ (pk 'resolve-primitive-type type)
type)
(define-matcher (resolve:type-variable (? type-variable? var) dict)
+ (pk 'resolve-var var)
(let ((type (assq-ref dict var)))
- (if (or (not type) (type-variable? type))
- (error "cannot determine type" var)
- type)))
+ (if type
+ (resolve-types type dict)
+ var)))
(define-matcher (resolve:procedure-type (? procedure-type? type) dict)
+ (pk 'resolve-proc type)
(make-procedure-type (resolve-types (procedure-type-parameter-types type) dict)
(resolve-types (procedure-type-return-types type) dict)))
(define-matcher (resolve:list (exps ...) dict)
+ (pk 'resolve-list exps)
(map (lambda (texp) (resolve-types texp dict)) exps))
(define resolve-types
@@ -409,15 +507,10 @@
resolve:primitive))
(define (infer-types exp)
- (let* ((texp (annotate-exp* exp))
- (constraints (pk 'constraints (program-constraints texp)))
- (substitutions (unify-constraints constraints)))
- (and substitutions
- (resolve-types texp substitutions))))
-
-(infer-types #t)
-(infer-types 6)
-(infer-types 6.5)
-(infer-types '(if #t 1 2))
-(infer-types '((lambda (x) x) 1))
-;;(infer-types '((lambda (f) (if (f #t) (f 1) (f 2))) (lambda (x) x)))
+ (parameterize ((unique-counter 0))
+ (pk 'begin-inference)
+ (let* ((texp (pk 'texp (annotate-exp* exp)))
+ (constraints (pk 'constraints (program-constraints texp)))
+ (substitutions (pk 'substitutions (unify-constraints constraints))))
+ (and substitutions
+ (resolve-types texp substitutions)))))