From 1fca41ed98b74bf4224ada808489b4b770d3ddf6 Mon Sep 17 00:00:00 2001 From: David Thompson Date: Sat, 7 Jan 2023 20:37:56 -0500 Subject: Type quantifiers!!!! --- infer2.scm | 283 ++++++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 188 insertions(+), 95 deletions(-) diff --git a/infer2.scm b/infer2.scm index b7211c7..03a8d32 100644 --- a/infer2.scm +++ b/infer2.scm @@ -4,6 +4,17 @@ (srfi srfi-9) (srfi srfi-9 gnu)) +(define (print-list lst) + (for-each (lambda (x) (format #t "~a\n" x)) lst)) + +(define (difference a b) + (match a + (() b) + ((x . rest) + (if (memq x b) + (difference rest (delq x b)) + (cons x (difference rest b)))))) + (define-syntax-rule (define-matcher (name rules ...) body ...) (define name (match-lambda* @@ -23,7 +34,7 @@ ;;; -;;; Typed expression annotation +;;; Types and environments ;;; (define-record-type @@ -74,6 +85,12 @@ (procedure-type-return-types type))) (set-record-type-printer! display-procedure-type) +(define-record-type + (make-for-all-type variables type) + for-all-type? + (variables for-all-type-variables) + (type for-all-type-type)) + (define (type? obj) (or (primitive-type? obj) (type-variable? obj) @@ -83,8 +100,117 @@ '()) (define (lookup var env) - (or (assq-ref env var) - (error "unbound variable" var))) + (let ((type (assq-ref env var))) + (cond + ((type-variable? type) type) + ((for-all-type? type) (instantiate type)) + (else (error "unbound variable" var))))) + +(define (occurs-in? a b) + (cond + ((and (type-variable? a) (type-variable? b)) + (eq? a b)) + ((and (type-variable? a) (procedure-type? b)) + (or (occurs-in? a (procedure-type-parameter-types b)) + (occurs-in? a (procedure-type-return-types b)))) + ((and (type? a) (list? b)) + (any (lambda (b*) (occurs-in? a b*)) b)) + (else #f))) + +(define (apply-substitution-to-type type from to) + (cond + ((primitive-type? type) type) + ((type-variable? type) + (if (eq? type from) to type)) + ((procedure-type? type) + (make-procedure-type + (map (lambda (param-type) + (apply-substitution-to-type param-type from to)) + (procedure-type-parameter-types type)) + (map (lambda (return-type) + (apply-substitution-to-type return-type from to)) + (procedure-type-return-types type)))))) + +(define (apply-substitutions-to-type type subs) + (let loop ((type type) + (subs subs)) + (match subs + (() type) + (((from . to) . rest) + (loop (apply-substitution-to-type type from to) + rest))))) + +(define (apply-substitution-to-env from to env) + (map (match-lambda + ((a . b) + (cons (if (eq? a from) to a) + (if (eq? b from) to b)))) + env)) + +(define (substitute from to env) + (let* ((to* (apply-substitutions-to-type to env))) + (and (not (occurs-in? from to*)) + (let ((env* (apply-substitution-to-env from to* env))) + (pk 'add-to-env from to*) + (alist-cons from to* env*))))) + +(define (free-variables-in-type type) + (cond + ((primitive-type? type) '()) + ((type-variable? type) type) + ((procedure-type? type) + (delete-duplicates + (append (map free-variables-in-type + (procedure-type-parameter-types type)) + (map free-variables-in-type + (procedure-type-return-types type))))))) + +(define (free-variables-in-for-all for-all) + (difference (for-all-type-variables for-all) + (free-variables-in-type (for-all-type-type for-all)))) + +(define (free-variables-in-env env) + (delete-duplicates + (let loop ((env env)) + (match env + (() '()) + (((_ . type) . rest) + (cond + ((type-variable? type) + (cons (free-variables-in-type type) + (loop rest))) + ((for-all-type? type) + (cons (free-variables-in-for-all type) + (loop rest))) + (else + (loop rest)))))))) + +(define (instantiate for-all) + (define subs + (map (lambda (var) + (cons var (fresh-type-variable))) + (for-all-type-variables for-all))) + (pk 'instantiate for-all) + (apply-substitutions-to-type (for-all-type-type for-all) subs)) + +(define (generalize type env) + (define quantifiers + (difference (free-variables-in-type type) + (free-variables-in-env env))) + (pk 'quantifiers quantifiers) + (if (null? quantifiers) + type + (make-for-all-type quantifiers type))) + +(define (constraints->env constraints) + (map (lambda (c) + (cons (constraint-lhs c) (constraint-rhs c))) + constraints)) + + +;;; +;;; Typed expression annotation +;;; (define (make-texp types exp) `(t ,types ,exp)) @@ -102,9 +228,10 @@ (match texp (('t _ exp) exp))) -(define (display-typed-expression texp port) - (format port "#" (texp-types texp) (texp-exp texp))) -(set-record-type-printer! display-typed-expression) +(define (single-type texp) + (match (texp-types texp) + ((type) type) + (_ (error "expected only 1 type" texp)))) (define-matcher (annotate:bool (? boolean? b) env) (make-texp (list bool-type) b)) @@ -126,29 +253,25 @@ (define-matcher (annotate:values ('values vals ...) env) (define vals* (map (lambda (val) (annotate-exp val env)) vals)) - (define val-types - (map (lambda (val) - (match (texp-types val) - ((type) type) - (types (error "expressions in values must return 1 value" val)))) - vals*)) + (define val-types (map single-type vals*)) (make-texp val-types `(values ,@vals*))) (define-matcher (annotate:let ('let (((? symbol? vars) vals) ...) body) env) (define vals* (map (lambda (val) (annotate-exp val env)) vals)) (define val-types (map (lambda (val) - (match (texp-types val) - ((type) type) - (_ (error "let bindings must return 1 value" val)))) + (generalize (single-type val) env)) vals*)) + (define vals** (map (lambda (val type) + (make-texp (list type) (texp-exp val))) + vals* val-types)) (define env* (append (map cons vars val-types) env)) (define body* (annotate-exp body env*)) - (make-texp (map (lambda (_type) (fresh-type-variable)) (texp-types body*)) - `(let ,(map list vars vals*) ,body*))) + (make-texp (texp-types body*) + `(let ,(map list vars vals**) ,body*))) (define-matcher (annotate:begin ('begin exps ... last) env) (define last* (annotate-exp last env)) - (make-texp (map (lambda (_type) (fresh-type-variable)) (texp-types last*)) + (make-texp (texp-types last*) `(begin ,@(map (lambda (exp) (annotate-exp exp env)) exps) ,last*))) @@ -156,17 +279,20 @@ (define parameter-types (map (lambda (_name) (fresh-type-variable)) args)) (define env* (append (map cons args parameter-types) env)) (define body* (annotate-exp body env*)) - (define return-types - (map (lambda (_type) (fresh-type-variable)) (texp-types body*))) + (define return-types (texp-types body*)) (make-texp (list (make-procedure-type parameter-types return-types)) `(lambda ,args ,body*))) (define-matcher (annotate:call (operator args ...) env) (define operator* (annotate-exp operator env)) + (define args* (map (lambda (arg) + (annotate-exp arg env)) + args)) (make-texp (map (lambda (_type) (fresh-type-variable)) (texp-types operator*)) (cons operator* - (map (lambda (arg) (annotate-exp arg env)) args)))) + args*))) + (define annotate-exp (compose-matchers annotate:bool @@ -181,8 +307,7 @@ annotate:call)) (define (annotate-exp* exp) - (parameterize ((unique-counter 0)) - (annotate-exp exp (top-level-env)))) + (annotate-exp exp (top-level-env))) ;;; @@ -221,28 +346,36 @@ (define-matcher (constrain:let ((? type? types) ...) ('let (((? symbol? vars) (? texp? vals)) ...) (? texp? body))) - (cons (constrain types (texp-types body)) - (append (program-constraints body) - (append-map program-constraints vals)))) + (append (program-constraints body) + (append-map program-constraints vals))) (define-matcher (constrain:begin ((? type? types) ...) ('begin (? texp? texps) ... (? texp? last))) - (cons (constrain types (texp-types last)) - (append (program-constraints last) - (append-map program-constraints texps)))) + (append (program-constraints last) + (append-map program-constraints texps))) + +(define-matcher (constrain:for-all-lambda ((? for-all-type? type)) + ('lambda ((? symbol? args) ...) + (? texp? body))) + (cons (pk 'constrain-for-all-lambda + (constrain (procedure-type-return-types + (for-all-type-type type)) + (texp-types body))) + (program-constraints body))) (define-matcher (constrain:lambda ((? procedure-type? type)) ('lambda ((? symbol? args) ...) (? texp? body))) - (cons (constrain (procedure-type-return-types type) - (texp-types body)) + (cons (pk 'constrain-lambda + (constrain (procedure-type-return-types type) + (texp-types body))) (program-constraints body))) (define-matcher (constrain:call ((? type? types) ...) ((? texp? operator) (? texp? operands) ...)) - (cons (constrain (texp-types operator) - (list (make-procedure-type (map texp-type operands) - types))) + (cons (pk 'constrain-call (constrain (texp-types operator) + (list (make-procedure-type (map single-type operands) + types)))) (append (program-constraints operator) (append-map program-constraints operands)))) @@ -251,63 +384,19 @@ constrain:values constrain:let constrain:begin + constrain:for-all-lambda constrain:lambda constrain:call constrain:other)) (define (program-constraints texp) - (%program-constraints (texp-type texp) (texp-exp texp))) - - -;;; -;;; Occurs check -;;; - -(define-matcher (occurs:default _ _) - #f) - -(define-matcher (occurs:list (a rest-a ...) (b rest-b ...)) - (or (occurs-in? a b) - (occurs-in? rest-a rest-b))) - -(define-matcher (occurs:variable (? type-variable? a) (? type-variable? b)) - (eq? a b)) - -(define-matcher (occurs:procedure (? type-variable? v) (? procedure-type? p)) - (or (occurs-in? v (procedure-type-parameter-types p)) - (occurs-in? v (procedure-type-return-types p)))) - -(define occurs-in? - (compose-matchers occurs:list - occurs:variable - occurs:procedure - occurs:default)) + (%program-constraints (texp-types texp) (texp-exp texp))) ;;; ;;; Unification ;;; -(define (substitute-term term dict) - (match dict - (() term) - (((from . to) . rest) - (if (eq? term from) - to - (substitute-term term rest))))) - -(define (substitute-dict var term dict) - (map (match-lambda - ((a . b) - (cons (if (eq? a var) term a) - (if (eq? b var) term b)))) - dict)) - -(define (substitute var other dict) - (let ((other* (substitute-term other dict))) - (and (not (occurs-in? var other*)) - (alist-cons var other* (substitute-dict var other* dict))))) - (define %unify-prompt-tag (make-prompt-tag 'unify)) (define (unify-fail) @@ -318,14 +407,17 @@ (cond ;; Tautology: matching 2 vars that are the same. ((and (type-variable? other) (eq? var other)) + (pk 'tautology var other) dict) ;; Variable has been bound to some other value, recursively follow ;; it and unify. ((assq-ref dict var) => (lambda (type) + (pk 'forward var other) (unify type other dict))) ;; Substitute variable for value. (else + (pk 'substitute var other) (or (substitute var other dict) (unify-fail))))) @@ -358,6 +450,7 @@ (define dict* (unify (procedure-type-parameter-types a) (procedure-type-parameter-types b) dict)) + (pk 'dict* dict*) (unify (procedure-type-return-types a) (procedure-type-return-types b) dict*)) @@ -383,22 +476,27 @@ ;;; (define-matcher (resolve:primitive x dict) + (pk 'resolve-primitive) x) (define-matcher (resolve:primitive-type (? primitive-type? type) dict) + (pk 'resolve-primitive-type type) type) (define-matcher (resolve:type-variable (? type-variable? var) dict) + (pk 'resolve-var var) (let ((type (assq-ref dict var))) - (if (or (not type) (type-variable? type)) - (error "cannot determine type" var) - type))) + (if type + (resolve-types type dict) + var))) (define-matcher (resolve:procedure-type (? procedure-type? type) dict) + (pk 'resolve-proc type) (make-procedure-type (resolve-types (procedure-type-parameter-types type) dict) (resolve-types (procedure-type-return-types type) dict))) (define-matcher (resolve:list (exps ...) dict) + (pk 'resolve-list exps) (map (lambda (texp) (resolve-types texp dict)) exps)) (define resolve-types @@ -409,15 +507,10 @@ resolve:primitive)) (define (infer-types exp) - (let* ((texp (annotate-exp* exp)) - (constraints (pk 'constraints (program-constraints texp))) - (substitutions (unify-constraints constraints))) - (and substitutions - (resolve-types texp substitutions)))) - -(infer-types #t) -(infer-types 6) -(infer-types 6.5) -(infer-types '(if #t 1 2)) -(infer-types '((lambda (x) x) 1)) -;;(infer-types '((lambda (f) (if (f #t) (f 1) (f 2))) (lambda (x) x))) + (parameterize ((unique-counter 0)) + (pk 'begin-inference) + (let* ((texp (pk 'texp (annotate-exp* exp))) + (constraints (pk 'constraints (program-constraints texp))) + (substitutions (pk 'substitutions (unify-constraints constraints)))) + (and substitutions + (resolve-types texp substitutions))))) -- cgit v1.2.3