summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--infer.scm211
1 files changed, 211 insertions, 0 deletions
diff --git a/infer.scm b/infer.scm
new file mode 100644
index 0000000..c9da6ef
--- /dev/null
+++ b/infer.scm
@@ -0,0 +1,211 @@
+(use-modules (ice-9 match)
+ (srfi srfi-1)
+ (srfi srfi-9))
+
+(define-record-type <named-type>
+ (make-named-type name)
+ named-type?
+ (name named-type-name))
+
+(define-record-type <variable-type>
+ (make-variable-type name)
+ variable-type?
+ (name variable-type-name))
+
+(define-record-type <function-type>
+ (make-function-type from to)
+ function-type?
+ (from function-type-from)
+ (to function-type-to))
+
+(define int (make-named-type 'int))
+(define bool (make-named-type 'bool))
+
+(define unique-counter (make-parameter 0))
+
+(define (unique-number)
+ (let ((n (unique-counter)))
+ (unique-counter (+ n 1))
+ n))
+
+(define* (unique-identifier #:optional (prefix 'T))
+ (string->symbol
+ (format #f "~a~a" prefix (unique-number))))
+
+(define (make-fresh-variable-type)
+ (make-variable-type (unique-identifier)))
+
+(define (substitute-type subs type)
+ (cond
+ ((named-type? type)
+ type)
+ ((variable-type? type)
+ (or (assq-ref subs type) type))
+ ((function-type? type)
+ (let* ((from (function-type-from type))
+ (to (function-type-to type))
+ (from* (substitute-type subs from))
+ (to* (substitute-type subs to)))
+ (if (and (eq? from from*) (eq? to to*))
+ type
+ (make-function-type from* to*))))))
+
+(define (substitute-env subs env)
+ (map (match-lambda
+ ((name . type)
+ (cons name (substitute-type subs type))))
+ env))
+
+(define (substitute-constraints subs constraints)
+ (map (match-lambda
+ ((a . b)
+ (cons (substitute-type subs a) (substitute-type subs b))))
+ constraints))
+
+(define (contains? a b)
+ (cond
+ ((variable-type? a)
+ (eq? a b))
+ ((named-type? a) #f)
+ ((function-type? a)
+ (or (contains? (function-type-from a) b)
+ (contains? (function-type-to a) b)))))
+
+(define (make-constraints exp env)
+ (match exp
+ ((and (? number?) (? exact-integer?))
+ (values (list (cons exp int)) '()))
+ ((? boolean?)
+ (values (list (cons exp bool)) '()))
+ ((? symbol?)
+ (values (list (cons exp
+ (or (assq-ref env exp)
+ (error "unbound variable" exp))))
+ '()))
+ (('lambda (arg) body)
+ (define arg-type (make-fresh-variable-type))
+ (define-values (body-env body-constraints)
+ (make-constraints body (alist-cons arg arg-type env)))
+ (values (cons* (cons exp (make-function-type arg-type
+ (assq-ref body-env body)))
+ (cons arg arg-type)
+ body-env)
+ body-constraints))
+ (('if test consequent alternate)
+ (define-values (test-env test-constraints)
+ (make-constraints test env))
+ (define-values (consequent-env consequent-constraints)
+ (make-constraints consequent env))
+ (define-values (alternate-env alternate-constraints)
+ (make-constraints alternate env))
+ (values (cons (cons exp (assq-ref consequent-env consequent))
+ (append test-env
+ consequent-env
+ alternate-env
+ env))
+ (append (list (cons (assq-ref test-env test) bool)
+ (cons (assq-ref consequent-env consequent)
+ (assq-ref alternate-env alternate)))
+ test-constraints
+ consequent-constraints
+ alternate-constraints)))
+ ((proc args ...)
+ (define-values (%arg-envs %arg-constraints)
+ (unzip2
+ (map (lambda (arg)
+ (call-with-values (lambda ()
+ (make-constraints arg env))
+ list))
+ args)))
+ (define arg-env (concatenate %arg-envs))
+ (define arg-constraints (concatenate %arg-constraints))
+ (define-values (proc-env proc-constraints)
+ (make-constraints proc env))
+ ;; (define-values (arg-env arg-constraints)
+ ;; (make-constraints arg env))
+ (define return-type (make-fresh-variable-type))
+ (define call-type (make-function-type (assq-ref arg-env (car args))
+ return-type))
+ (values (append (list (cons exp return-type))
+ proc-env
+ arg-env)
+ (append (list (cons (assq-ref proc-env proc) call-type))
+ proc-constraints
+ arg-constraints)))
+ (_
+ (error "invalid expression" exp))))
+
+(define %default-env
+ `((not . ,(make-function-type bool bool))
+ (add1 . ,(make-function-type int int))
+ (sub1 . ,(make-function-type int int))))
+
+(define (make-constraints* exp)
+ (parameterize ((unique-counter 0))
+ (define-values (env constraints)
+ (make-constraints exp %default-env))
+ (values (delete-duplicates env) constraints)))
+
+(define (unify a b)
+ (define (sub-var var type)
+ (cond
+ ;; Type is also a variable, so we can't do anything.
+ ((eq? var type)
+ '())
+ ;; Variable appears within type, which is not allowed.
+ ((contains? type var)
+ (error "circular reference" var type))
+ (else
+ (list (cons var type)))))
+ (cond
+ ;; A and B are the same simple type (like int or bool.)
+ ((and (named-type? a) (named-type? b) (eq? a b))
+ '())
+ ;; A or B is a type variable.
+ ((variable-type? a)
+ (sub-var a b))
+ ((variable-type? b)
+ (sub-var b a))
+ ;; A and B are function types.
+ ((and (function-type? a) (function-type? b))
+ (let* ((a-subs (unify (function-type-from a) (function-type-from b)))
+ (b-subs (unify (substitute-type a-subs (function-type-to a))
+ (substitute-type a-subs (function-type-to b)))))
+ (append a-subs b-subs)))
+ ;; Oh no.
+ (else
+ (error "type mismtach" a b))))
+
+;; Successively transform the type environment by applying
+;; constraints. If there are no type mismatches or other errors then
+;; a new type environment in which all type variables have been
+;; removed is returned.
+(define (solve-constraints env constraints)
+ (match constraints
+ (() env)
+ (((a . b) . rest)
+ ;; First, attempt to unify the 2 types in the constraint and get
+ ;; the substitutions that unification creates. Subsitutions need
+ ;; to be applied to the type environment *and* the remaining
+ ;; constraints.
+ (let* ((new-subs (unify a b)))
+ (solve-constraints (substitute-env new-subs env)
+ (substitute-constraints new-subs rest))))))
+
+(define (infer exp)
+ (define-values (env constraints)
+ (make-constraints* exp))
+ (assq-ref (solve-constraints env constraints) exp))
+
+(define (test-equal a b)
+ (unless (equal? a b)
+ (error "fail:" a b)))
+
+(test-equal (infer 6) int)
+(test-equal (infer #t) bool)
+(test-equal (infer #f) bool)
+(test-equal (false-if-exception (infer 'x)) #f)
+(test-equal (infer '(lambda (x) (not x))) (make-function-type bool bool))
+(test-equal (infer '((lambda (x) x) 6)) int)
+(test-equal (infer '((lambda (x) (if (not #t) (add1 x) (sub1 x))) 1)) int)
+(test-equal (false-if-exception (infer '((lambda (x) (if #t 1 x)) #f))) #f)