summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--infer.scm255
1 files changed, 212 insertions, 43 deletions
diff --git a/infer.scm b/infer.scm
index b1e468c..81b8eee 100644
--- a/infer.scm
+++ b/infer.scm
@@ -1,28 +1,112 @@
-(use-modules (ice-9 match)
+(use-modules (ice-9 format)
+ (ice-9 match)
(srfi srfi-1)
- (srfi srfi-9))
+ (srfi srfi-9)
+ (srfi srfi-9 gnu))
-(define-record-type <named-type>
- (make-named-type name)
- named-type?
- (name named-type-name))
+(define-record-type <primitive-type>
+ (make-primitive-type name)
+ primitive-type?
+ (name primitive-type-name))
+
+(define (display-primitive-type type port)
+ (format port "#<primitive-type ~a>" (primitive-type-name type)))
+(set-record-type-printer! <primitive-type> display-primitive-type)
+
+(define-record-type <struct-type>
+ (make-struct-type name fields)
+ struct-type?
+ (name struct-type-name)
+ (fields struct-type-fields))
+
+(define (display-struct-type type port)
+ (format port "#<struct-type ~a>" (struct-type-name type)))
+(set-record-type-printer! <struct-type> display-struct-type)
+
+(define-record-type <struct-ref-type>
+ (make-struct-ref-type struct field)
+ struct-ref-type?
+ (struct struct-ref-type-struct)
+ (field struct-ref-type-field))
+
+(define-record-type <array-type>
+ (make-array-type name type)
+ array-type?
+ (name array-type-name)
+ (type array-type-type))
+
+(define (display-array-type type port)
+ (format port "#<array-type ~a>" (array-type-name type)))
+(set-record-type-printer! <array-type> display-array-type)
+
+(define-record-type <array-ref-type>
+ (make-array-ref-type array index)
+ array-ref-type?
+ (array array-ref-type-array)
+ (index array-ref-type-index))
(define-record-type <variable-type>
(make-variable-type name)
variable-type?
(name variable-type-name))
+(define (display-variable-type type port)
+ (format port "#<variable-type ~a>" (variable-type-name type)))
+(set-record-type-printer! <variable-type> display-variable-type)
+
(define-record-type <function-type>
(make-function-type from to)
function-type?
(from function-type-from)
(to function-type-to))
-(define (make-function-type* . types)
+(define (display-function-type type port)
+ (format port "#<function-type ~a -> ~a>"
+ (function-type-from type)
+ (function-type-to type)))
+(set-record-type-printer! <function-type> display-function-type)
+
+(define (named-type? type)
+ (or (primitive-type? type)
+ (struct-type? type)
+ (array-type? type)))
+
+(define (function . types)
(reduce-right make-function-type #f types))
-(define int (make-named-type 'int))
-(define bool (make-named-type 'bool))
+(define int (make-primitive-type 'int))
+(define float (make-primitive-type 'float))
+(define bool (make-primitive-type 'bool))
+(define vec2
+ (make-struct-type 'vec2
+ `((x . ,float)
+ (y . ,float))))
+(define vec3
+ (make-struct-type 'vec3
+ `((x . ,float)
+ (y . ,float)
+ (z . ,float))))
+(define vec4
+ (make-struct-type 'vec4
+ `((x . ,float)
+ (y . ,float)
+ (z . ,float)
+ (w . ,float))))
+(define mat4 (make-array-type 'mat4 (make-array-type 'mat4-row float)))
+
+(define %default-env
+ `((not . ,(make-function-type bool bool))
+ (+ . ,(function int int int))
+ (* . ,(function int int int))
+ (= . ,(function int int bool))
+ (< . ,(function int int bool))
+ (<= . ,(function int int bool))
+ (> . ,(function int int bool))
+ (>= . ,(function int int bool))
+ (vec2 . ,(function float float vec2))
+ (vec3 . ,(function float float float vec3))
+ (vec4 . ,(function float float float float vec4))
+ (mat4 . ,(function mat4))))
(define unique-counter (make-parameter 0))
@@ -35,7 +119,7 @@
(string->symbol
(format #f "~a~a" prefix (unique-number))))
-(define (make-fresh-variable-type)
+(define (fresh-variable)
(make-variable-type (unique-identifier)))
(define (substitute-type subs type)
@@ -43,12 +127,50 @@
((named-type? type)
type)
((variable-type? type)
+ ;; Substitute variable with its actual type, or return the
+ ;; variable if its type is still unknown.
(or (assq-ref subs type) type))
+ ((struct-ref-type? type)
+ (let* ((struct (struct-ref-type-struct type))
+ (field (struct-ref-type-field type))
+ (struct* (substitute-type subs struct)))
+ (cond
+ ;; Substituted type is a struct type, so we can resolve the
+ ;; reference.
+ ((struct-type? struct*)
+ (or (assq-ref (struct-type-fields struct*) field)
+ (error "no such field in struct" struct* field)))
+ ;; Substitution didn't change anything, return the original
+ ;; type.
+ ((eq? struct struct*)
+ type)
+ ;; Substitution hasn't yet produced a struct to reference.
+ (else
+ (make-struct-ref-type struct* field)))))
+ ((array-ref-type? type)
+ (let* ((array (array-ref-type-array type))
+ (index (array-ref-type-index type))
+ (array* (substitute-type subs array))
+ (index* (substitute-type subs index)))
+ (cond
+ ;; Substituted type is an array type, so we can subsitute the
+ ;; type of the array elements.
+ ((array-type? array*)
+ (array-type-type array*))
+ ;; Substitution didn't change anything, return the original
+ ;; type.
+ ((and (eq? array array*) (eq? index index*))
+ type)
+ ;; Substitution hasn't yet produced an array type.
+ (else
+ (make-array-ref-type array* index*)))))
((function-type? type)
(let* ((from (function-type-from type))
(to (function-type-to type))
(from* (substitute-type subs from))
(to* (substitute-type subs to)))
+ ;; If no substitution has occurred, return the original type and
+ ;; avoid some unnecessary allocation.
(if (and (eq? from from*) (eq? to to*))
type
(make-function-type from* to*))))))
@@ -78,22 +200,51 @@
(match exp
((and (? number?) (? exact-integer?))
(values (list (cons exp int)) '()))
+ ((and (? number?) (? inexact?))
+ (values (list (cons exp float)) '()))
((? boolean?)
(values (list (cons exp bool)) '()))
((? symbol?)
- (values (list (cons exp
- (or (assq-ref env exp)
- (error "unbound variable" exp))))
- '()))
- (('lambda (arg) body)
- (define arg-type (make-fresh-variable-type))
+ (if (assq-ref env exp)
+ (values '() '())
+ (error "unbound variable" exp)))
+ (('lambda ((? symbol? args) ...) body)
+ (define arg-vars (map (lambda (_arg) (fresh-variable)) args))
+ (define env* (append (map cons args arg-vars) env))
(define-values (body-env body-constraints)
- (make-constraints body (alist-cons arg arg-type env)))
- (values (cons* (cons exp (make-function-type arg-type
- (assq-ref body-env body)))
- (cons arg arg-type)
- body-env)
+ (make-constraints body env*))
+ (define env** (append body-env env*))
+ (define body-type (assq-ref env** body))
+ (define lambda-type
+ (fold-right make-function-type body-type arg-vars))
+ (values (cons (cons exp lambda-type)
+ env**)
body-constraints))
+ (('let ((vars vals) ...) body)
+ (define-values (%value-env %value-constraints)
+ (unzip2
+ (map (lambda (value)
+ (call-with-values
+ (lambda ()
+ (make-constraints value env))
+ list))
+ vals)))
+ (define value-env (concatenate %value-env))
+ (define value-constraints (concatenate %value-constraints))
+ (define var-types
+ (map (lambda (_var) (fresh-variable)) vars))
+ (define env* (append (map cons vars var-types) env))
+ (define-values (body-env body-constraints)
+ (make-constraints body env*))
+ (define env** (append body-env env*))
+ (values (append (list (cons exp (assq-ref env** body)))
+ value-env
+ env**)
+ (append (map (lambda (var-type value)
+ (cons var-type (assq-ref value-env value)))
+ var-types vals)
+ value-constraints
+ body-constraints)))
(('if test consequent alternate)
(define-values (test-env test-constraints)
(make-constraints test env))
@@ -112,47 +263,58 @@
test-constraints
consequent-constraints
alternate-constraints)))
+ (('-> struct (? symbol? field))
+ (define-values (struct-env struct-constraints)
+ (make-constraints struct env))
+ (define ref-type
+ (make-struct-ref-type (assq-ref struct-env struct)
+ field))
+ (values (append (list (cons exp ref-type))
+ struct-env
+ env)
+ struct-constraints))
+ (('@ array index)
+ (define-values (array-env array-constraints)
+ (make-constraints array env))
+ (define-values (index-env index-constraints)
+ (make-constraints index env))
+ (define ref-type
+ (make-array-ref-type (assq-ref array-env array)
+ (assq-ref index-env index)))
+ (values (append (list (cons exp ref-type)
+ array-env
+ index-env
+ env))
+ (append (list (cons (assq-ref index-env index) int))
+ array-constraints
+ index-constraints)))
((proc args ...)
- (define-values (%arg-envs %arg-constraints)
+ (define-values (%arg-env %arg-constraints)
(unzip2
(map (lambda (arg)
(call-with-values (lambda ()
(make-constraints arg env))
list))
args)))
- (define arg-env (concatenate %arg-envs))
+ (define arg-env (append (concatenate %arg-env) env))
(define arg-constraints (concatenate %arg-constraints))
(define-values (proc-env proc-constraints)
(make-constraints proc env))
- (define return-type (make-fresh-variable-type))
+ (define return-type (fresh-variable))
(define call-type
(fold-right make-function-type return-type
(map (lambda (arg)
(assq-ref arg-env arg))
args)))
+ (define env* (append proc-env arg-env))
(values (append (list (cons exp return-type))
- proc-env
- arg-env)
- (append (list (cons (assq-ref proc-env proc) call-type))
+ env*)
+ (append (list (cons (assq-ref env* proc) call-type))
proc-constraints
arg-constraints)))
(_
(error "invalid expression" exp))))
-(define %default-env
- `((not . ,(make-function-type bool bool))
- (add1 . ,(make-function-type int int))
- (sub1 . ,(make-function-type int int))
- (+ . ,(make-function-type* int int int))
- (- . ,(make-function-type* int int int))
- (* . ,(make-function-type* int int int))
- (/ . ,(make-function-type* int int int))
- (= . ,(make-function-type* int int bool))
- (< . ,(make-function-type* int int bool))
- (<= . ,(make-function-type* int int bool))
- (> . ,(make-function-type* int int bool))
- (>= . ,(make-function-type* int int bool))))
-
(define (make-constraints* exp)
(parameterize ((unique-counter 0))
(define-values (env constraints)
@@ -171,7 +333,7 @@
(else
(list (cons var type)))))
(cond
- ;; A and B are the same simple type (like int or bool.)
+ ;; A and B are the same primitive or struct type.
((and (named-type? a) (named-type? b) (eq? a b))
'())
;; A or B is a type variable.
@@ -220,7 +382,14 @@
(test-equal (false-if-exception (infer 'x)) #f)
(test-equal (infer '(lambda (x) (not x))) (make-function-type bool bool))
(test-equal (infer '((lambda (x) x) 6)) int)
-(test-equal (infer '((lambda (x) (if (not #t) (add1 x) (sub1 x))) 1)) int)
+(test-equal (infer '((lambda (x) (if (not #t) (+ x 1) (+ x 2))) 1)) int)
(test-equal (false-if-exception (infer '((lambda (x) (if #t 1 x)) #f))) #f)
(test-equal (infer '((lambda (x) (+ 1 x)) 2)) int)
(test-equal (infer '((lambda (x) (= 1 x)) 2)) bool)
+(test-equal (infer '(vec2 1.0 2.0)) vec2)
+(test-equal (infer '(-> (vec2 1.0 2.0) x)) float)
+(test-equal (infer '((lambda (x y) (+ 1 2)) 2 3)) int)
+(test-equal (infer '(let ((x 1)
+ (f (lambda (x) x)))
+ (* (+ (f x) 1) (f x))))
+ int)