diff options
-rw-r--r-- | infer.scm | 211 |
1 files changed, 211 insertions, 0 deletions
diff --git a/infer.scm b/infer.scm new file mode 100644 index 0000000..c9da6ef --- /dev/null +++ b/infer.scm @@ -0,0 +1,211 @@ +(use-modules (ice-9 match) + (srfi srfi-1) + (srfi srfi-9)) + +(define-record-type <named-type> + (make-named-type name) + named-type? + (name named-type-name)) + +(define-record-type <variable-type> + (make-variable-type name) + variable-type? + (name variable-type-name)) + +(define-record-type <function-type> + (make-function-type from to) + function-type? + (from function-type-from) + (to function-type-to)) + +(define int (make-named-type 'int)) +(define bool (make-named-type 'bool)) + +(define unique-counter (make-parameter 0)) + +(define (unique-number) + (let ((n (unique-counter))) + (unique-counter (+ n 1)) + n)) + +(define* (unique-identifier #:optional (prefix 'T)) + (string->symbol + (format #f "~a~a" prefix (unique-number)))) + +(define (make-fresh-variable-type) + (make-variable-type (unique-identifier))) + +(define (substitute-type subs type) + (cond + ((named-type? type) + type) + ((variable-type? type) + (or (assq-ref subs type) type)) + ((function-type? type) + (let* ((from (function-type-from type)) + (to (function-type-to type)) + (from* (substitute-type subs from)) + (to* (substitute-type subs to))) + (if (and (eq? from from*) (eq? to to*)) + type + (make-function-type from* to*)))))) + +(define (substitute-env subs env) + (map (match-lambda + ((name . type) + (cons name (substitute-type subs type)))) + env)) + +(define (substitute-constraints subs constraints) + (map (match-lambda + ((a . b) + (cons (substitute-type subs a) (substitute-type subs b)))) + constraints)) + +(define (contains? a b) + (cond + ((variable-type? a) + (eq? a b)) + ((named-type? a) #f) + ((function-type? a) + (or (contains? (function-type-from a) b) + (contains? (function-type-to a) b))))) + +(define (make-constraints exp env) + (match exp + ((and (? number?) (? exact-integer?)) + (values (list (cons exp int)) '())) + ((? boolean?) + (values (list (cons exp bool)) '())) + ((? symbol?) + (values (list (cons exp + (or (assq-ref env exp) + (error "unbound variable" exp)))) + '())) + (('lambda (arg) body) + (define arg-type (make-fresh-variable-type)) + (define-values (body-env body-constraints) + (make-constraints body (alist-cons arg arg-type env))) + (values (cons* (cons exp (make-function-type arg-type + (assq-ref body-env body))) + (cons arg arg-type) + body-env) + body-constraints)) + (('if test consequent alternate) + (define-values (test-env test-constraints) + (make-constraints test env)) + (define-values (consequent-env consequent-constraints) + (make-constraints consequent env)) + (define-values (alternate-env alternate-constraints) + (make-constraints alternate env)) + (values (cons (cons exp (assq-ref consequent-env consequent)) + (append test-env + consequent-env + alternate-env + env)) + (append (list (cons (assq-ref test-env test) bool) + (cons (assq-ref consequent-env consequent) + (assq-ref alternate-env alternate))) + test-constraints + consequent-constraints + alternate-constraints))) + ((proc args ...) + (define-values (%arg-envs %arg-constraints) + (unzip2 + (map (lambda (arg) + (call-with-values (lambda () + (make-constraints arg env)) + list)) + args))) + (define arg-env (concatenate %arg-envs)) + (define arg-constraints (concatenate %arg-constraints)) + (define-values (proc-env proc-constraints) + (make-constraints proc env)) + ;; (define-values (arg-env arg-constraints) + ;; (make-constraints arg env)) + (define return-type (make-fresh-variable-type)) + (define call-type (make-function-type (assq-ref arg-env (car args)) + return-type)) + (values (append (list (cons exp return-type)) + proc-env + arg-env) + (append (list (cons (assq-ref proc-env proc) call-type)) + proc-constraints + arg-constraints))) + (_ + (error "invalid expression" exp)))) + +(define %default-env + `((not . ,(make-function-type bool bool)) + (add1 . ,(make-function-type int int)) + (sub1 . ,(make-function-type int int)))) + +(define (make-constraints* exp) + (parameterize ((unique-counter 0)) + (define-values (env constraints) + (make-constraints exp %default-env)) + (values (delete-duplicates env) constraints))) + +(define (unify a b) + (define (sub-var var type) + (cond + ;; Type is also a variable, so we can't do anything. + ((eq? var type) + '()) + ;; Variable appears within type, which is not allowed. + ((contains? type var) + (error "circular reference" var type)) + (else + (list (cons var type))))) + (cond + ;; A and B are the same simple type (like int or bool.) + ((and (named-type? a) (named-type? b) (eq? a b)) + '()) + ;; A or B is a type variable. + ((variable-type? a) + (sub-var a b)) + ((variable-type? b) + (sub-var b a)) + ;; A and B are function types. + ((and (function-type? a) (function-type? b)) + (let* ((a-subs (unify (function-type-from a) (function-type-from b))) + (b-subs (unify (substitute-type a-subs (function-type-to a)) + (substitute-type a-subs (function-type-to b))))) + (append a-subs b-subs))) + ;; Oh no. + (else + (error "type mismtach" a b)))) + +;; Successively transform the type environment by applying +;; constraints. If there are no type mismatches or other errors then +;; a new type environment in which all type variables have been +;; removed is returned. +(define (solve-constraints env constraints) + (match constraints + (() env) + (((a . b) . rest) + ;; First, attempt to unify the 2 types in the constraint and get + ;; the substitutions that unification creates. Subsitutions need + ;; to be applied to the type environment *and* the remaining + ;; constraints. + (let* ((new-subs (unify a b))) + (solve-constraints (substitute-env new-subs env) + (substitute-constraints new-subs rest)))))) + +(define (infer exp) + (define-values (env constraints) + (make-constraints* exp)) + (assq-ref (solve-constraints env constraints) exp)) + +(define (test-equal a b) + (unless (equal? a b) + (error "fail:" a b))) + +(test-equal (infer 6) int) +(test-equal (infer #t) bool) +(test-equal (infer #f) bool) +(test-equal (false-if-exception (infer 'x)) #f) +(test-equal (infer '(lambda (x) (not x))) (make-function-type bool bool)) +(test-equal (infer '((lambda (x) x) 6)) int) +(test-equal (infer '((lambda (x) (if (not #t) (add1 x) (sub1 x))) 1)) int) +(test-equal (false-if-exception (infer '((lambda (x) (if #t 1 x)) #f))) #f) |