summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Thompson <dthompson2@worcester.edu>2023-02-04 11:40:26 -0500
committerDavid Thompson <dthompson2@worcester.edu>2023-06-08 08:14:41 -0400
commitd7d6b0d739ef1f08c4765c625d0428592fdc67c3 (patch)
tree6421cdc8a3c1817b38c2ddb14db4268e3b07027c
parentcbc4c95a0820934600a26329ee58cc5685060dcb (diff)
Infer all possible forms.
-rw-r--r--chickadee/graphics/seagull.scm133
1 files changed, 111 insertions, 22 deletions
diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm
index 1ae82bb..05003d9 100644
--- a/chickadee/graphics/seagull.scm
+++ b/chickadee/graphics/seagull.scm
@@ -1661,6 +1661,83 @@
combined-subs
pred))
+(define (infer:values exps env)
+ (define-values (exps* exp-subs exp-pred)
+ (infer:list exps env))
+ (values (texp (map single-type exps*)
+ `(values ,@exps*))
+ exp-subs
+ exp-pred))
+
+(define (infer:outputs names exps env)
+ (define-values (exps* exp-subs exp-pred)
+ (infer:list exps env))
+ (values (texp (map single-type exps*)
+ `(outputs ,@(map list names exps*)))
+ exp-subs
+ exp-pred))
+
+(define (infer:top-level bindings body env)
+ (define (infer-bindings bindings texps subs pred)
+ (match bindings
+ (()
+ (values (reverse texps) subs pred))
+ ((('function _ exp) . rest)
+ (define-values (texp subs* pred*)
+ (infer-exp exp env))
+ (define-values (new-pred combined-subs)
+ (eval-predicate* (compose-predicates pred pred*)
+ (compose-substitutions subs subs*)))
+ (infer-bindings rest
+ (cons texp texps)
+ combined-subs
+ new-pred))
+ (((_ type-name _) . rest)
+ (infer-bindings rest
+ (cons (list (type-name->type type-name)) texps)
+ subs
+ pred))))
+ (define qualifiers (map first bindings))
+ (define names
+ (map (match-lambda
+ (('function name _) name)
+ ((_ _ name) name))
+ bindings))
+ (define type-names
+ (map (match-lambda
+ (((or 'in 'out) type-name _) type-name)
+ (_ #f))
+ bindings))
+ (define-values (exps exp-subs exp-pred)
+ (infer-bindings bindings '() '() #t))
+ (define exp-types
+ (map (lambda (x)
+ (if (texp? x)
+ (texp-types x)
+ x))
+ exps))
+ (define env*
+ (fold extend-env
+ (apply-substitutions-to-env env exp-subs)
+ names
+ exp-types))
+ (define-values (body* body-subs body-pred)
+ (infer-exp body env*))
+ (define-values (pred combined-subs)
+ (eval-predicate* (compose-predicates exp-pred body-pred)
+ (compose-substitutions exp-subs body-subs)))
+ (define bindings*
+ (map (match-lambda*
+ (((and (or 'in 'out) qualifier) type-name name _)
+ (list qualifier type-name name))
+ (('function _ name exp)
+ `(function ,name ,exp)))
+ qualifiers type-names names exps))
+ (values (texp (texp-types body*)
+ `(top-level ,bindings* ,body*))
+ combined-subs
+ pred))
+
;; Inference returns 3 values:
;; - a typed expression
;; - a list of substitutions
@@ -1669,40 +1746,49 @@
(match exp
((? immediate?)
(infer:immediate exp))
- ((? symbol?)
- (infer:variable exp env))
+ ((or (? symbol? name) ('var name _))
+ (infer:variable name env))
(('if predicate consequent alternate)
(infer:if predicate consequent alternate env))
(('let ((names exps) ...) body)
(infer:let names exps body env))
(('lambda (params ...) body)
(infer:lambda params body env))
- ;; (('values exps ...)
- ;; (infer:values exps env))
+ (('values exps ...)
+ (infer:values exps env))
(('primcall operator args ...)
(infer:primitive-call operator args env))
(('call operator args ...)
(infer:call operator args env))
- ;; (('outputs (names exps) ...)
- ;; (infer:outputs names exps env))
- ;; (('top-level bindings body)
- ;; (infer:top-level bindings body env))
+ (('outputs (names exps) ...)
+ (infer:outputs names exps env))
+ (('top-level bindings body)
+ (infer:top-level bindings body env))
(_ (error "unknown form" exp))))
+(define (make-test-env)
+ (extend-env
+ '+
+ (list (let ((a (fresh-type-variable))
+ (b (fresh-type-variable))
+ (c (fresh-type-variable)))
+ (for-all-type
+ (list a b c)
+ (function-type (list a b)
+ (list c))
+ `(or (when (and (= ,a ,int-type)
+ (= ,b ,int-type))
+ (substitute ,c ,a))
+ (when (and (= ,a ,float-type)
+ (= ,b ,float-type))
+ (substitute ,c ,a))))))
+ (empty-env)))
+
;; TODO: Add some kind of context object that is threaded through the
;; inference process so that when a type error occurs we can show the
;; expression that caused it.
(define (infer-types exp stage)
- (call-with-unify-rollback
- (lambda ()
- (let ((annotated (pk 'annotated (annotate-exp* exp stage))))
- (infer annotated
- '()
- (lambda (subs)
- (resolve annotated subs)))))
- (match-lambda*
- ((msg . args)
- (seagull-type-error msg args infer-types)))))
+ (infer-exp exp (make-test-env)))
;;;
@@ -1929,7 +2015,7 @@
(emit:float n version port level))
(('t _ (? boolean? b))
(emit:boolean b version port level))
- (('t _ ('var var _))
+ (('t _ (? symbol? var))
(list var))
(('t _ ('if predicate consequent alternate))
(emit:if predicate consequent alternate version port level))
@@ -1937,11 +2023,11 @@
(emit:values exps version port level))
(('t types ('let ((names exps) ...) body))
(emit:let types names exps body version port level))
- (('t (type) ('primcall ('t _ (? binary-operator? op)) a b))
+ (('t (type) ('primcall (? binary-operator? op) a b))
(emit:binary-operator type op a b version port level))
- (('t (type) ('primcall ('t _ (? unary-operator? op)) a))
+ (('t (type) ('primcall (? unary-operator? op) a))
(emit:unary-operator type op a version port level))
- (('t (type) ('primcall ('t _ op) args ...))
+ (('t (type) ('primcall op args ...))
(emit:primcall type op args version port level))
(('t types ('call operator args ...))
(emit:call types operator args version port level))
@@ -1958,6 +2044,8 @@
;; Combine all of the compiler passes on a user provided program and
;; emit GLSL code if the program is valid.
+(use-modules (ice-9 pretty-print))
+
(define* (compile-seagull exp #:key (stage 'vertex)
(version '330)
(port (current-output-port)))
@@ -1967,4 +2055,5 @@
(propagated (propagate-constants expanded (empty-env)))
(hoisted (hoist-functions* propagated))
(inferred (infer-types hoisted stage)))
+ (pretty-print inferred)
(emit-glsl inferred version port))))