;;; Prototype type inference for a small expression language: numbers,
;;; booleans, variables, lambda, let, if, struct field references (->),
;;; array references (@) and procedure calls.
;;;
;;; Inference happens in two phases:
;;;
;;; 1. MAKE-CONSTRAINTS walks an expression and assigns every
;;;    sub-expression a type (often a fresh type variable).  Types live
;;;    in an association-list "type environment" keyed by the
;;;    sub-expression object itself (looked up with assq, i.e. eq?).
;;;    The walk also collects (type . type) equality constraints.
;;; 2. SOLVE-CONSTRAINTS unifies each constraint in order.  It applies
;;;    the resulting substitutions to the environment and to the
;;;    remaining constraints.
;;;
;;; INFER runs both phases and returns the type of the whole expression.
;;;
;;; NOTE(review): UNIFY has no rule for struct-ref/array-ref types.
;;; A constraint that pairs one of them with a concrete type before its
;;; container type is known raises "type mismatch", even if a later
;;; constraint would have resolved it.

(use-modules (ice-9 format)
             (ice-9 match)
             (srfi srfi-1)
             (srfi srfi-9)
             (srfi srfi-9 gnu))


;;;
;;; Type representations
;;;

;; A built-in scalar type such as int, float or bool.
(define-record-type <primitive-type>
  (make-primitive-type name)
  primitive-type?
  (name primitive-type-name))

(define (display-primitive-type type port)
  (format port "#<primitive-type ~a>" (primitive-type-name type)))

(set-record-type-printer! <primitive-type> display-primitive-type)

;; A named struct type.  FIELDS is an alist of (field-name . type).
(define-record-type <struct-type>
  (make-struct-type name fields)
  struct-type?
  (name struct-type-name)
  (fields struct-type-fields))

(define (display-struct-type type port)
  (format port "#<struct-type ~a>" (struct-type-name type)))

(set-record-type-printer! <struct-type> display-struct-type)

;; The type of field FIELD of a value of type STRUCT.  STRUCT may still
;; be a type variable.  SUBSTITUTE-TYPE resolves the reference once
;; STRUCT has been substituted with a struct type.
(define-record-type <struct-ref-type>
  (make-struct-ref-type struct field)
  struct-ref-type?
  (struct struct-ref-type-struct)
  (field struct-ref-type-field))

(define (display-struct-ref-type type port)
  (format port "#<struct-ref-type ~a ~a>"
          (struct-ref-type-struct type)
          (struct-ref-type-field type)))

(set-record-type-printer! <struct-ref-type> display-struct-ref-type)

;; A named array type whose elements have type TYPE.
(define-record-type <array-type>
  (make-array-type name type)
  array-type?
  (name array-type-name)
  (type array-type-type))

(define (display-array-type type port)
  (format port "#<array-type ~a>" (array-type-name type)))

(set-record-type-printer! <array-type> display-array-type)

;; The element type obtained by indexing a value of type ARRAY with a
;; value of type INDEX.  SUBSTITUTE-TYPE resolves it once ARRAY has been
;; substituted with an array type.
(define-record-type <array-ref-type>
  (make-array-ref-type array index)
  array-ref-type?
  (array array-ref-type-array)
  (index array-ref-type-index))

(define (display-array-ref-type type port)
  (format port "#<array-ref-type ~a ~a>"
          (array-ref-type-array type)
          (array-ref-type-index type)))

(set-record-type-printer! <array-ref-type> display-array-ref-type)

;; A type variable standing for a not-yet-known type.  Variables are
;; compared with eq?; NAME is only used when printing.
(define-record-type <variable-type>
  (make-variable-type name)
  variable-type?
  (name variable-type-name))

(define (display-variable-type type port)
  (format port "#<variable-type ~a>" (variable-type-name type)))

(set-record-type-printer! <variable-type> display-variable-type)

;; A one-argument function type FROM -> TO.  Procedures of several
;; arguments are represented in curried form (see FUNCTION).
(define-record-type <function-type>
  (make-function-type from to)
  function-type?
  (from function-type-from)
  (to function-type-to))

(define (display-function-type type port)
  (format port "#<function-type ~a -> ~a>"
          (function-type-from type)
          (function-type-to type)))

(set-record-type-printer! <function-type> display-function-type)

;; True for concrete types.  UNIFY compares these by identity (eq?),
;; and CONTAINS? treats them as holding no type variables.
(define (named-type? type)
  (or (primitive-type? type)
      (struct-type? type)
      (array-type? type)))

;; Build a curried function type: (function a b c) => a -> (b -> c).
;; A single type is returned unchanged.
(define (function . types)
  (reduce-right make-function-type #f types))

(define int (make-primitive-type 'int))
(define float (make-primitive-type 'float))
(define bool (make-primitive-type 'bool))
(define vec2 (make-struct-type 'vec2 `((x . ,float) (y . ,float))))
(define vec3 (make-struct-type 'vec3 `((x . ,float) (y . ,float) (z . ,float))))
(define vec4 (make-struct-type 'vec4 `((x . ,float) (y . ,float) (z . ,float) (w . ,float))))
(define mat4 (make-array-type 'mat4 (make-array-type 'mat4-row float)))

;; Types of the built-in procedures, used as the initial environment.
(define %default-env
  `((not . ,(make-function-type bool bool))
    (+ . ,(function int int int))
    (* . ,(function int int int))
    (= . ,(function int int bool))
    (< . ,(function int int bool))
    (<= . ,(function int int bool))
    (> . ,(function int int bool))
    (>= . ,(function int int bool))
    (vec2 . ,(function float float vec2))
    (vec3 . ,(function float float float vec3))
    (vec4 . ,(function float float float float vec4))
    (mat4 . ,(function mat4))))


;;;
;;; Fresh type variables
;;;

;; Counter used to name fresh type variables.  MAKE-CONSTRAINTS* rebinds
;; it to 0 for each inference run.
(define unique-counter (make-parameter 0))

(define (unique-number)
  (let ((n (unique-counter)))
    (unique-counter (+ n 1))
    n))

(define* (unique-identifier #:optional (prefix 'T))
  (string->symbol (format #f "~a~a" prefix (unique-number))))

(define (fresh-variable)
  (make-variable-type (unique-identifier)))


;;;
;;; Substitution
;;;

;; Apply the substitution SUBS, an alist of (type-variable . type), to
;; TYPE.  The original TYPE object is returned whenever the substitution
;; changes nothing, to avoid needless allocation.
(define (substitute-type subs type)
  (cond
   ((named-type? type) type)
   ((variable-type? type)
    ;; Substitute the variable with its actual type, or return the
    ;; variable if its type is still unknown.
    (or (assq-ref subs type) type))
   ((struct-ref-type? type)
    (let* ((struct (struct-ref-type-struct type))
           (field (struct-ref-type-field type))
           (struct* (substitute-type subs struct)))
      (cond
       ;; The substituted type is a struct type, so the reference can
       ;; be resolved.
       ((struct-type? struct*)
        (or (assq-ref (struct-type-fields struct*) field)
            (error "no such field in struct" struct* field)))
       ;; Substitution didn't change anything; return the original type.
       ((eq? struct struct*) type)
       ;; Substitution hasn't yet produced a struct to reference.
       (else (make-struct-ref-type struct* field)))))
   ((array-ref-type? type)
    (let* ((array (array-ref-type-array type))
           (index (array-ref-type-index type))
           (array* (substitute-type subs array))
           (index* (substitute-type subs index)))
      (cond
       ;; The substituted type is an array type, so the reference
       ;; becomes the type of the array elements.
       ((array-type? array*)
        (array-type-type array*))
       ;; Substitution didn't change anything; return the original type.
       ((and (eq? array array*) (eq? index index*))
        type)
       ;; Substitution hasn't yet produced an array type.
       (else (make-array-ref-type array* index*)))))
   ((function-type? type)
    (let* ((from (function-type-from type))
           (to (function-type-to type))
           (from* (substitute-type subs from))
           (to* (substitute-type subs to)))
      ;; If no substitution has occurred, return the original type and
      ;; avoid some unnecessary allocation.
      (if (and (eq? from from*) (eq? to to*))
          type
          (make-function-type from* to*))))
   ;; Anything else indicates an internal error; fail loudly rather than
   ;; propagating an unspecified value.
   (else (error "invalid type" type))))

;; Apply SUBS to every type in the alist ENV, keeping the keys.  Also
;; used to compose substitutions, since a substitution is such an alist.
(define (substitute-env subs env)
  (map (match-lambda
         ((name . type)
          (cons name (substitute-type subs type))))
       env))

;; Apply SUBS to both sides of every (type . type) constraint.
(define (substitute-constraints subs constraints)
  (map (match-lambda
         ((a . b)
          (cons (substitute-type subs a)
                (substitute-type subs b))))
       constraints))

;; True if the type variable B occurs anywhere inside type A.
(define (contains? a b)
  (cond
   ((variable-type? a) (eq? a b))
   ((named-type? a) #f)
   ;; Unresolved references may still mention type variables.
   ((struct-ref-type? a)
    (contains? (struct-ref-type-struct a) b))
   ((array-ref-type? a)
    (or (contains? (array-ref-type-array a) b)
        (contains? (array-ref-type-index a) b)))
   ((function-type? a)
    (or (contains? (function-type-from a) b)
        (contains? (function-type-to a) b)))
   (else (error "invalid type" a))))


;;;
;;; Constraint generation
;;;

;; Look up the type recorded for expression EXP in ENV.  Every
;; sub-expression gets an entry (variables via their binding), so a
;; miss is an internal error.
(define (expression-type exp env)
  (or (assq-ref env exp)
      (error "no type for expression" exp)))

;; Return two values for expression EXP in type environment ENV:
;; - an alist extending ENV (or a part of it) that maps EXP and its
;;   sub-expressions to their types;
;; - a list of (type . type) constraints that must hold.
;; Variable references return an empty alist; their type is their
;; binding in ENV, so types are always looked up in the returned alist
;; *appended to* ENV.
(define (make-constraints exp env)
  (match exp
    ((and (? number?) (? exact-integer?))
     (values (list (cons exp int)) '()))
    ((and (? number?) (? inexact?))
     (values (list (cons exp float)) '()))
    ((? boolean?)
     (values (list (cons exp bool)) '()))
    ((? symbol?)
     (if (assq-ref env exp)
         (values '() '())
         (error "unbound variable" exp)))
    (('lambda ((? symbol? args) ...) body)
     (define arg-vars (map (lambda (_arg) (fresh-variable)) args))
     (define env* (append (map cons args arg-vars) env))
     (define-values (body-env body-constraints)
       (make-constraints body env*))
     (define env** (append body-env env*))
     (define body-type (expression-type body env**))
     (define lambda-type (fold-right make-function-type body-type arg-vars))
     (values (cons (cons exp lambda-type) env**)
             body-constraints))
    (('let ((vars vals) ...) body)
     (define-values (%value-env %value-constraints)
       (unzip2
        (map (lambda (value)
               (call-with-values
                   (lambda () (make-constraints value env))
                 list))
             vals)))
     (define value-env (concatenate %value-env))
     (define value-constraints (concatenate %value-constraints))
     ;; A value may be a plain variable whose type lives only in ENV.
     (define value-lookup-env (append value-env env))
     (define var-types (map (lambda (_var) (fresh-variable)) vars))
     (define env* (append (map cons vars var-types) env))
     (define-values (body-env body-constraints)
       (make-constraints body env*))
     (define env** (append body-env env*))
     (values (append (list (cons exp (expression-type body env**)))
                     value-env
                     env**)
             (append (map (lambda (var-type value)
                            (cons var-type
                                  (expression-type value value-lookup-env)))
                          var-types vals)
                     value-constraints
                     body-constraints)))
    (('if test consequent alternate)
     (define-values (test-env test-constraints)
       (make-constraints test env))
     (define-values (consequent-env consequent-constraints)
       (make-constraints consequent env))
     (define-values (alternate-env alternate-constraints)
       (make-constraints alternate env))
     (define test-type
       (expression-type test (append test-env env)))
     (define consequent-type
       (expression-type consequent (append consequent-env env)))
     (define alternate-type
       (expression-type alternate (append alternate-env env)))
     (values (cons (cons exp consequent-type)
                   (append test-env consequent-env alternate-env env))
             ;; The test must be a bool and both branches must agree.
             (append (list (cons test-type bool)
                           (cons consequent-type alternate-type))
                     test-constraints
                     consequent-constraints
                     alternate-constraints)))
    (('-> struct (? symbol? field))
     (define-values (struct-env struct-constraints)
       (make-constraints struct env))
     (define env* (append struct-env env))
     (define ref-type
       (make-struct-ref-type (expression-type struct env*) field))
     (values (cons (cons exp ref-type) env*)
             struct-constraints))
    (('@ array index)
     (define-values (array-env array-constraints)
       (make-constraints array env))
     (define-values (index-env index-constraints)
       (make-constraints index env))
     (define array-type (expression-type array (append array-env env)))
     (define index-type (expression-type index (append index-env env)))
     (define ref-type (make-array-ref-type array-type index-type))
     (values (cons (cons exp ref-type)
                   (append array-env index-env env))
             ;; Array indices must be ints.
             (append (list (cons index-type int))
                     array-constraints
                     index-constraints)))
    ((proc args ...)
     (define-values (%arg-env %arg-constraints)
       (unzip2
        (map (lambda (arg)
               (call-with-values
                   (lambda () (make-constraints arg env))
                 list))
             args)))
     (define arg-env (append (concatenate %arg-env) env))
     (define arg-constraints (concatenate %arg-constraints))
     (define-values (proc-env proc-constraints)
       (make-constraints proc env))
     (define return-type (fresh-variable))
     ;; The curried type the call site requires of PROC.
     (define call-type
       (fold-right make-function-type
                   return-type
                   (map (lambda (arg) (expression-type arg arg-env)) args)))
     (define env* (append proc-env arg-env))
     (values (cons (cons exp return-type) env*)
             (append (list (cons (expression-type proc env*) call-type))
                     proc-constraints
                     arg-constraints)))
    (_ (error "invalid expression" exp))))

;; Run MAKE-CONSTRAINTS on EXP in the default environment with a fresh
;; variable counter.  The default environment is always included in the
;; result so that a bare bound variable such as 'not has a type entry.
(define (make-constraints* exp)
  (parameterize ((unique-counter 0))
    (define-values (env constraints)
      (make-constraints exp %default-env))
    (values (delete-duplicates (append env %default-env)) constraints)))


;;;
;;; Unification and solving
;;;

;; Return a substitution (alist of type-variable . type) that makes A
;; and B equal.  Raises an error on a type mismatch or when a variable
;; would have to contain itself.  The returned substitution is composed,
;; so no binding's value mentions a variable bound elsewhere in it.
(define (unify a b)
  (define (bind-variable var type)
    (cond
     ;; Unifying a variable with itself needs no substitution.
     ((eq? var type) '())
     ;; The variable appears within the type (occurs check).
     ((contains? type var)
      (error "circular reference" var type))
     (else (list (cons var type)))))
  (cond
   ;; A and B are the same primitive, struct or array type.
   ((and (named-type? a) (named-type? b) (eq? a b))
    '())
   ;; A or B is a type variable.
   ((variable-type? a) (bind-variable a b))
   ((variable-type? b) (bind-variable b a))
   ;; A and B are function types: unify argument types, then the
   ;; result types under the argument substitution.
   ((and (function-type? a) (function-type? b))
    (let* ((from-subs (unify (function-type-from a) (function-type-from b)))
           (to-subs (unify (substitute-type from-subs (function-type-to a))
                           (substitute-type from-subs (function-type-to b)))))
      ;; Compose: rewrite FROM-SUBS with TO-SUBS so that no binding is
      ;; left pointing at a variable that TO-SUBS resolves.
      (append (substitute-env to-subs from-subs) to-subs)))
   ;; Oh no.
   (else (error "type mismatch" a b))))

;; Successively transform the type environment by applying
;; constraints.  If there are no type mismatches or other errors, the
;; returned environment has every type variable that the constraints
;; determine replaced by its type.  Unconstrained variables remain.
(define (solve-constraints env constraints)
  (match constraints
    (() env)
    (((a . b) . rest)
     ;; First, unify the two types in the constraint to get the
     ;; substitutions it implies.  Substitutions need to be applied to
     ;; the type environment *and* the remaining constraints.
     (let ((new-subs (unify a b)))
       (solve-constraints (substitute-env new-subs env)
                          (substitute-constraints new-subs rest))))))

;; Return the inferred type of expression EXP.
(define (infer exp)
  (define-values (env constraints)
    (make-constraints* exp))
  (assq-ref (solve-constraints env constraints) exp))


;;;
;;; Tests
;;;

(define (test-equal a b)
  (unless (equal? a b)
    (error "fail:" a b)))

(test-equal (infer 6) int)
(test-equal (infer #t) bool)
(test-equal (infer #f) bool)
(test-equal (false-if-exception (infer 'x)) #f)
(test-equal (infer '(lambda (x) (not x)))
            (make-function-type bool bool))
(test-equal (infer '((lambda (x) x) 6)) int)
(test-equal (infer '((lambda (x) (if (not #t) (+ x 1) (+ x 2))) 1)) int)
(test-equal (false-if-exception (infer '((lambda (x) (if #t 1 x)) #f))) #f)
(test-equal (infer '((lambda (x) (+ 1 x)) 2)) int)
(test-equal (infer '((lambda (x) (= 1 x)) 2)) bool)
(test-equal (infer '(vec2 1.0 2.0)) vec2)
(test-equal (infer '(-> (vec2 1.0 2.0) x)) float)
(test-equal (infer '((lambda (x y) (+ 1 2)) 2 3)) int)
(test-equal (infer '(let ((x 1)
                          (f (lambda (x) x)))
                      (* (+ (f x) 1) (f x))))
            int)
;; Regression tests: variables as if branches, struct and array
;; operands; composed substitutions; bare bound variables.
(test-equal (infer '((lambda (x) (if #t x 1)) 2)) int)
(test-equal (infer '((lambda (v) (-> v x)) (vec2 1.0 2.0))) float)
(test-equal (infer '((lambda (m) (@ m 0)) (mat4))) (array-type-type mat4))
(test-equal (infer '(if #t (lambda (a) a) (lambda (b) 1)))
            (make-function-type int int))
(test-equal (infer 'not) (make-function-type bool bool))