summaryrefslogtreecommitdiff
path: root/infer.scm
blob: c9da6ef74f54e619503d5a8d76d4496792be542b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
(use-modules (ice-9 match)
             (srfi srfi-1)
             (srfi srfi-9))

;; A concrete, fully known type such as int or bool.  NAME is the
;; symbol that identifies it, e.g. 'int.
(define-record-type <named-type>
  (make-named-type name)
  named-type?
  (name named-type-name))

;; A type variable: a placeholder for a type that is not yet known.
;; NAME is a symbol; fresh ones are produced by
;; `make-fresh-variable-type' (T0, T1, ...).  Variables are identified
;; by object identity (`eq?'), not by NAME.
(define-record-type <variable-type>
  (make-variable-type name)
  variable-type?
  (name variable-type-name))

;; The type of a one-argument function that accepts a value of type
;; FROM and returns a value of type TO.
(define-record-type <function-type>
  (make-function-type from to)
  function-type?
  (from function-type-from)
  (to function-type-to))

;; The primitive types of the language: exact integer literals are
;; typed `int' and boolean literals are typed `bool'.
(define int (make-named-type 'int))
(define bool (make-named-type 'bool))

;; Source of fresh numbers for type-variable names.  It is a parameter
;; so that each inference run can restart numbering at 0 with
;; `parameterize' (as `make-constraints*' does).
(define unique-counter (make-parameter 0))

;; Return the current value of `unique-counter', then advance the
;; counter by one so the next call yields a different number.
(define (unique-number)
  (define current (unique-counter))
  (unique-counter (1+ current))
  current)

;; Return a fresh symbol made of PREFIX (default T) followed by the
;; next unique number, e.g. T0, T1, ...
(define* (unique-identifier #:optional (prefix 'T))
  (let* ((n (unique-number))
         (name (format #f "~a~a" prefix n)))
    (string->symbol name)))

;; Create a new type variable whose name has not been handed out
;; before in the current numbering run.
(define (make-fresh-variable-type)
  (let ((name (unique-identifier)))
    (make-variable-type name)))

;; Apply SUBS, an alist keyed by type variables, to TYPE.
;;  - A named type is returned unchanged.
;;  - A type variable is replaced by its binding in SUBS, if any.
;;  - A function type is rebuilt only when its argument or result type
;;    actually changed; otherwise TYPE itself is returned, so unchanged
;;    structure stays shared (`eq?').
(define (substitute-type subs type)
  (define (walk t)
    (substitute-type subs t))
  (cond
   ((variable-type? type)
    (let ((binding (assq-ref subs type)))
      (if binding binding type)))
   ((function-type? type)
    (let* ((old-from (function-type-from type))
           (old-to (function-type-to type))
           (new-from (walk old-from))
           (new-to (walk old-to)))
      (if (and (eq? new-from old-from) (eq? new-to old-to))
          type
          (make-function-type new-from new-to))))
   ((named-type? type)
    type)))

;; Apply SUBS to the type of every (NAME . TYPE) entry of the type
;; environment ENV, returning a new alist with the same keys in the
;; same order.
(define (substitute-env subs env)
  (define (substitute-entry entry)
    (match entry
      ((name . type)
       (cons name (substitute-type subs type)))))
  (map substitute-entry env))

;; Apply SUBS to both sides of every (TYPE . TYPE) equality in
;; CONSTRAINTS, returning a new list in the same order.
(define (substitute-constraints subs constraints)
  (define (substitute-side type)
    (substitute-type subs type))
  (map (lambda (constraint)
         (match constraint
           ((lhs . rhs)
            (cons (substitute-side lhs) (substitute-side rhs)))))
       constraints))

;; Return true when the type variable B occurs anywhere inside type A
;; (including when A is B itself).  This is the occurs check used by
;; unification to reject circular types.
(define (contains? a b)
  (cond
   ((named-type? a) #f)
   ((variable-type? a) (eq? a b))
   ((function-type? a)
    (let ((from (function-type-from a))
          (to (function-type-to a)))
      (or (contains? from b)
          (contains? to b))))))

;; Generate typing information for expression EXP under the type
;; environment ENV, an alist mapping variable names to types.
;;
;; Returns two values:
;;   1. an alist mapping EXP, its sub-expressions and the variables
;;      bound inside it to their (possibly still unsolved) types;
;;   2. a list of (TYPE . TYPE) pairs, each asserting that the two
;;      types must be equal.
;;
;; Entries are looked up with `assq-ref', i.e. by `eq?' identity, so
;; each sub-expression is keyed by the very object appearing in EXP.
;;
;; Supported forms: exact integers, booleans, variables, one-parameter
;; `lambda', `if', and application of a procedure to one or more
;; arguments.  An unbound variable signals "unbound variable"; any
;; other form (including a call with no arguments) signals
;; "invalid expression".
(define (make-constraints exp env)
  (match exp
    ((and (? number?) (? exact-integer?))
     (values (list (cons exp int)) '()))
    ((? boolean?)
     (values (list (cons exp bool)) '()))
    ((? symbol?)
     (values (list (cons exp
                         (or (assq-ref env exp)
                             (error "unbound variable" exp))))
             '()))
    (('lambda (arg) body)
     ;; The parameter's type is unknown until the body constrains it.
     (define arg-type (make-fresh-variable-type))
     (define-values (body-env body-constraints)
       (make-constraints body (alist-cons arg arg-type env)))
     (values (cons* (cons exp (make-function-type arg-type
                                                  (assq-ref body-env body)))
                    (cons arg arg-type)
                    body-env)
             body-constraints))
    (('if test consequent alternate)
     (define-values (test-env test-constraints)
       (make-constraints test env))
     (define-values (consequent-env consequent-constraints)
       (make-constraints consequent env))
     (define-values (alternate-env alternate-constraints)
       (make-constraints alternate env))
     ;; The `if' takes the consequent's type; the constraints below
     ;; force the test to be a bool and both branches to agree.
     (values (cons (cons exp (assq-ref consequent-env consequent))
                   (append test-env
                           consequent-env
                           alternate-env
                           env))
             (append (list (cons (assq-ref test-env test) bool)
                           (cons (assq-ref consequent-env consequent)
                                 (assq-ref alternate-env alternate)))
                     test-constraints
                     consequent-constraints
                     alternate-constraints)))
    ((proc args ..1)
     ;; Constrain every argument, not only the first one.
     (define-values (arg-envs arg-constraints)
       (unzip2
        (map (lambda (arg)
               (call-with-values (lambda ()
                                   (make-constraints arg env))
                 list))
             args)))
     (define-values (proc-env proc-constraints)
       (make-constraints proc env))
     ;; Each argument is a key of its own environment.
     (define arg-types
       (map (lambda (arg arg-env)
              (assq-ref arg-env arg))
            args
            arg-envs))
     (define return-type (make-fresh-variable-type))
     ;; Function types are unary, so a multi-argument call is read as
     ;; curried application: (f a b) requires f : A -> (B -> R).
     ;; With a single argument this is just A -> R.
     (define call-type
       (fold-right make-function-type return-type arg-types))
     (values (cons (cons exp return-type)
                   (append proc-env (concatenate arg-envs)))
             (cons (cons (assq-ref proc-env proc) call-type)
                   (append proc-constraints
                           (concatenate arg-constraints)))))
    (_
     (error "invalid expression" exp))))

;; Types of the built-in procedures available to every expression:
;; `not' maps bool to bool, `add1' and `sub1' map int to int.
(define %default-env
  (list (cons 'not (make-function-type bool bool))
        (cons 'add1 (make-function-type int int))
        (cons 'sub1 (make-function-type int int))))

;; Top-level entry to constraint generation: analyse EXP against
;; %default-env with type-variable numbering restarted at 0, and
;; return the de-duplicated type environment plus the constraints.
(define (make-constraints* exp)
  (parameterize ((unique-counter 0))
    (call-with-values
        (lambda () (make-constraints exp %default-env))
      (lambda (env constraints)
        (values (delete-duplicates env) constraints)))))

;; Compute the substitutions that make types A and B equal.
;;
;; Returns an alist of (TYPE-VARIABLE . TYPE) pairs, already composed
;; so that applying it once with `substitute-type' resolves every
;; variable it binds.  Signals "circular reference" when a variable
;; would have to contain itself, and "type mismatch" when the two
;; types cannot be made equal.
(define (unify a b)
  ;; Bind the type variable VAR to TYPE.
  (define (sub-var var type)
    (cond
     ;; TYPE is VAR itself, so no substitution is needed.
     ((eq? var type)
      '())
     ;; Occurs check: binding VAR to a type containing VAR would
     ;; describe an infinite type.
     ((contains? type var)
      (error "circular reference" var type))
     (else
      (list (cons var type)))))
  (cond
   ;; A and B are the same simple type (like int or bool).  Compare by
   ;; name so separately constructed named types still unify.
   ((and (named-type? a) (named-type? b)
         (eq? (named-type-name a) (named-type-name b)))
    '())
   ;; A or B is a type variable.
   ((variable-type? a)
    (sub-var a b))
   ((variable-type? b)
    (sub-var b a))
   ;; A and B are function types: unify argument types first, then the
   ;; result types with what was learned applied.
   ((and (function-type? a) (function-type? b))
    (let* ((a-subs (unify (function-type-from a) (function-type-from b)))
           (b-subs (unify (substitute-type a-subs (function-type-to a))
                          (substitute-type a-subs (function-type-to b)))))
      ;; B-SUBS may bind variables that appear on the right-hand side
      ;; of A-SUBS: unifying T0 -> T1 with T1 -> int gives T0 := T1 and
      ;; T1 := int.  Apply B-SUBS to A-SUBS so that T0 resolves to int
      ;; rather than to the eliminated T1.
      (append (substitute-env b-subs a-subs) b-subs)))
   ;; Oh no.
   (else
    (error "type mismatch" a b))))

;; Solve CONSTRAINTS, a list of (TYPE . TYPE) equalities, against the
;; type environment ENV and return the resulting environment.
;;
;; Constraints are unified one at a time.  The substitutions produced
;; by each unification are applied both to the environment and to
;; every constraint still pending, so later constraints see what
;; earlier ones established.  Errors raised by `unify' (type mismatch,
;; circular reference) propagate to the caller.  Type variables that no
;; constraint pins down, such as the argument of (lambda (x) x), remain
;; in the returned environment.
(define (solve-constraints env constraints)
  (let loop ((env env)
             (pending constraints))
    (match pending
      (() env)
      (((lhs . rhs) . remaining)
       (let ((subs (unify lhs rhs)))
         (loop (substitute-env subs env)
               (substitute-constraints subs remaining)))))))

;; Infer the type of expression EXP in the default environment.
;; Signals an error if EXP is ill-typed or malformed, or if it uses an
;; unbound variable.
(define (infer exp)
  (call-with-values (lambda () (make-constraints* exp))
    (lambda (env constraints)
      (let ((solved (solve-constraints env constraints)))
        (assq-ref solved exp)))))

;; Minimal assertion helper: signal an error mentioning both values
;; unless A and B are `equal?'.
(define (test-equal a b)
  (if (not (equal? a b))
      (error "fail:" a b)))

;; Inference test cases, run in order.  Each case is either
;;   (ok EXPR EXPECTED-TYPE)  -- EXPR must infer to EXPECTED-TYPE, or
;;   (error EXPR)             -- inference of EXPR must fail.
(for-each
 (match-lambda
   (('ok expr expected)
    (test-equal (infer expr) expected))
   (('error expr)
    (test-equal (false-if-exception (infer expr)) #f)))
 `((ok 6 ,int)
   (ok #t ,bool)
   (ok #f ,bool)
   (error x)
   (ok (lambda (x) (not x)) ,(make-function-type bool bool))
   (ok ((lambda (x) x) 6) ,int)
   (ok ((lambda (x) (if (not #t) (add1 x) (sub1 x))) 1) ,int)
   (error ((lambda (x) (if #t 1 x)) #f))))