diff options
-rw-r--r-- | infer.scm | 255 |
1 files changed, 212 insertions, 43 deletions
@@ -1,28 +1,112 @@ -(use-modules (ice-9 match) +(use-modules (ice-9 format) + (ice-9 match) (srfi srfi-1) - (srfi srfi-9)) + (srfi srfi-9) + (srfi srfi-9 gnu)) -(define-record-type <named-type> - (make-named-type name) - named-type? - (name named-type-name)) +(define-record-type <primitive-type> + (make-primitive-type name) + primitive-type? + (name primitive-type-name)) + +(define (display-primitive-type type port) + (format port "#<primitive-type ~a>" (primitive-type-name type))) +(set-record-type-printer! <primitive-type> display-primitive-type) + +(define-record-type <struct-type> + (make-struct-type name fields) + struct-type? + (name struct-type-name) + (fields struct-type-fields)) + +(define (display-struct-type type port) + (format port "#<struct-type ~a>" (struct-type-name type))) +(set-record-type-printer! <struct-type> display-struct-type) + +(define-record-type <struct-ref-type> + (make-struct-ref-type struct field) + struct-ref-type? + (struct struct-ref-type-struct) + (field struct-ref-type-field)) + +(define-record-type <array-type> + (make-array-type name type) + array-type? + (name array-type-name) + (type array-type-type)) + +(define (display-array-type type port) + (format port "#<array-type ~a>" (array-type-name type))) +(set-record-type-printer! <array-type> display-array-type) + +(define-record-type <array-ref-type> + (make-array-ref-type array index) + array-ref-type? + (array array-ref-type-array) + (index array-ref-type-index)) (define-record-type <variable-type> (make-variable-type name) variable-type? (name variable-type-name)) +(define (display-variable-type type port) + (format port "#<variable-type ~a>" (variable-type-name type))) +(set-record-type-printer! <variable-type> display-variable-type) + (define-record-type <function-type> (make-function-type from to) function-type? (from function-type-from) (to function-type-to)) -(define (make-function-type* . types) +(define (display-function-type type port) + (format port "#<function-type ~a -> ~a>" + (function-type-from type) + (function-type-to type))) +(set-record-type-printer! <function-type> display-function-type) + +(define (named-type? type) + (or (primitive-type? type) + (struct-type? type) + (array-type? type))) + +(define (function . types) (reduce-right make-function-type #f types)) -(define int (make-named-type 'int)) -(define bool (make-named-type 'bool)) +(define int (make-primitive-type 'int)) +(define float (make-primitive-type 'float)) +(define bool (make-primitive-type 'bool)) +(define vec2 + (make-struct-type 'vec2 + `((x . ,float) + (y . ,float)))) +(define vec3 + (make-struct-type 'vec3 + `((x . ,float) + (y . ,float) + (z . ,float)))) +(define vec4 + (make-struct-type 'vec4 + `((x . ,float) + (y . ,float) + (z . ,float) + (w . ,float)))) +(define mat4 (make-array-type 'mat4 (make-array-type 'mat4-row float))) + +(define %default-env + `((not . ,(make-function-type bool bool)) + (+ . ,(function int int int)) + (* . ,(function int int int)) + (= . ,(function int int bool)) + (< . ,(function int int bool)) + (<= . ,(function int int bool)) + (> . ,(function int int bool)) + (>= . ,(function int int bool)) + (vec2 . ,(function float float vec2)) + (vec3 . ,(function float float float vec3)) + (vec4 . ,(function float float float float vec4)) + (mat4 . ,(function mat4)))) (define unique-counter (make-parameter 0)) @@ -35,7 +119,7 @@ (string->symbol (format #f "~a~a" prefix (unique-number)))) -(define (make-fresh-variable-type) +(define (fresh-variable) (make-variable-type (unique-identifier))) (define (substitute-type subs type) @@ -43,12 +127,50 @@ ((named-type? type) type) ((variable-type? type) + ;; Substitute variable with its actual type, or return the + ;; variable if its type is still unknown. (or (assq-ref subs type) type)) + ((struct-ref-type? type) + (let* ((struct (struct-ref-type-struct type)) + (field (struct-ref-type-field type)) + (struct* (substitute-type subs struct))) + (cond + ;; Substituted type is a struct type, so we can resolve the + ;; reference. + ((struct-type? struct*) + (or (assq-ref (struct-type-fields struct*) field) + (error "no such field in struct" struct* field))) + ;; Substitution didn't change anything, return the original + ;; type. + ((eq? struct struct*) + type) + ;; Substitution hasn't yet produced a struct to reference. + (else + (make-struct-ref-type struct* field))))) + ((array-ref-type? type) + (let* ((array (array-ref-type-array type)) + (index (array-ref-type-index type)) + (array* (substitute-type subs array)) + (index* (substitute-type subs index))) + (cond + ;; Substituted type is an array type, so we can subsitute the + ;; type of the array elements. + ((array-type? array*) + (array-type-type array*)) + ;; Substitution didn't change anything, return the original + ;; type. + ((and (eq? array array*) (eq? index index*)) + type) + ;; Substitution hasn't yet produced an array type. + (else + (make-array-ref-type array* index*))))) ((function-type? type) (let* ((from (function-type-from type)) (to (function-type-to type)) (from* (substitute-type subs from)) (to* (substitute-type subs to))) + ;; If no substitution has occurred, return the original type and + ;; avoid some unnecessary allocation. (if (and (eq? from from*) (eq? to to*)) type (make-function-type from* to*)))))) @@ -78,22 +200,51 @@ (match exp ((and (? number?) (? exact-integer?)) (values (list (cons exp int)) '())) + ((and (? number?) (? inexact?)) + (values (list (cons exp float)) '())) ((? boolean?) (values (list (cons exp bool)) '())) ((? symbol?) - (values (list (cons exp - (or (assq-ref env exp) - (error "unbound variable" exp)))) - '())) - (('lambda (arg) body) - (define arg-type (make-fresh-variable-type)) + (if (assq-ref env exp) + (values '() '()) + (error "unbound variable" exp))) + (('lambda ((? symbol? args) ...) body) + (define arg-vars (map (lambda (_arg) (fresh-variable)) args)) + (define env* (append (map cons args arg-vars) env)) (define-values (body-env body-constraints) - (make-constraints body (alist-cons arg arg-type env))) - (values (cons* (cons exp (make-function-type arg-type - (assq-ref body-env body))) - (cons arg arg-type) - body-env) + (make-constraints body env*)) + (define env** (append body-env env*)) + (define body-type (assq-ref env** body)) + (define lambda-type + (fold-right make-function-type body-type arg-vars)) + (values (cons (cons exp lambda-type) + env**) body-constraints)) + (('let ((vars vals) ...) body) + (define-values (%value-env %value-constraints) + (unzip2 + (map (lambda (value) + (call-with-values + (lambda () + (make-constraints value env)) + list)) + vals))) + (define value-env (concatenate %value-env)) + (define value-constraints (concatenate %value-constraints)) + (define var-types + (map (lambda (_var) (fresh-variable)) vars)) + (define env* (append (map cons vars var-types) env)) + (define-values (body-env body-constraints) + (make-constraints body env*)) + (define env** (append body-env env*)) + (values (append (list (cons exp (assq-ref env** body))) + value-env + env**) + (append (map (lambda (var-type value) + (cons var-type (assq-ref value-env value))) + var-types vals) + value-constraints + body-constraints))) (('if test consequent alternate) (define-values (test-env test-constraints) (make-constraints test env)) @@ -112,47 +263,58 @@ test-constraints consequent-constraints alternate-constraints))) + (('-> struct (? symbol? field)) + (define-values (struct-env struct-constraints) + (make-constraints struct env)) + (define ref-type + (make-struct-ref-type (assq-ref struct-env struct) + field)) + (values (append (list (cons exp ref-type)) + struct-env + env) + struct-constraints)) + (('@ array index) + (define-values (array-env array-constraints) + (make-constraints array env)) + (define-values (index-env index-constraints) + (make-constraints index env)) + (define ref-type + (make-array-ref-type (assq-ref array-env array) + (assq-ref index-env index))) + (values (append (list (cons exp ref-type) + array-env + index-env + env)) + (append (list (cons (assq-ref index-env index) int)) + array-constraints + index-constraints))) ((proc args ...) - (define-values (%arg-envs %arg-constraints) + (define-values (%arg-env %arg-constraints) (unzip2 (map (lambda (arg) (call-with-values (lambda () (make-constraints arg env)) list)) args))) - (define arg-env (concatenate %arg-envs)) + (define arg-env (append (concatenate %arg-env) env)) (define arg-constraints (concatenate %arg-constraints)) (define-values (proc-env proc-constraints) (make-constraints proc env)) - (define return-type (make-fresh-variable-type)) + (define return-type (fresh-variable)) (define call-type (fold-right make-function-type return-type (map (lambda (arg) (assq-ref arg-env arg)) args))) + (define env* (append proc-env arg-env)) (values (append (list (cons exp return-type)) - proc-env - arg-env) - (append (list (cons (assq-ref proc-env proc) call-type)) + env*) + (append (list (cons (assq-ref env* proc) call-type)) proc-constraints arg-constraints))) (_ (error "invalid expression" exp)))) -(define %default-env - `((not . ,(make-function-type bool bool)) - (add1 . ,(make-function-type int int)) - (sub1 . ,(make-function-type int int)) - (+ . ,(make-function-type* int int int)) - (- . ,(make-function-type* int int int)) - (* . ,(make-function-type* int int int)) - (/ . ,(make-function-type* int int int)) - (= . ,(make-function-type* int int bool)) - (< . ,(make-function-type* int int bool)) - (<= . ,(make-function-type* int int bool)) - (> . ,(make-function-type* int int bool)) - (>= . ,(make-function-type* int int bool)))) - (define (make-constraints* exp) (parameterize ((unique-counter 0)) (define-values (env constraints) @@ -171,7 +333,7 @@ (else (list (cons var type))))) (cond - ;; A and B are the same simple type (like int or bool.) + ;; A and B are the same primitive or struct type. ((and (named-type? a) (named-type? b) (eq? a b)) '()) ;; A or B is a type variable. @@ -220,7 +382,14 @@ (test-equal (false-if-exception (infer 'x)) #f) (test-equal (infer '(lambda (x) (not x))) (make-function-type bool bool)) (test-equal (infer '((lambda (x) x) 6)) int) -(test-equal (infer '((lambda (x) (if (not #t) (add1 x) (sub1 x))) 1)) int) +(test-equal (infer '((lambda (x) (if (not #t) (+ x 1) (+ x 2))) 1)) int) (test-equal (false-if-exception (infer '((lambda (x) (if #t 1 x)) #f))) #f) (test-equal (infer '((lambda (x) (+ 1 x)) 2)) int) (test-equal (infer '((lambda (x) (= 1 x)) 2)) bool) +(test-equal (infer '(vec2 1.0 2.0)) vec2) +(test-equal (infer '(-> (vec2 1.0 2.0) x)) float) +(test-equal (infer '((lambda (x y) (+ 1 2)) 2 3)) int) +(test-equal (infer '(let ((x 1) + (f (lambda (x) x))) + (* (+ (f x) 1) (f x)))) + int) |