From d7d6b0d739ef1f08c4765c625d0428592fdc67c3 Mon Sep 17 00:00:00 2001 From: David Thompson Date: Sat, 4 Feb 2023 11:40:26 -0500 Subject: Infer all possible forms. --- chickadee/graphics/seagull.scm | 133 ++++++++++++++++++++++++++++++++++------- 1 file changed, 111 insertions(+), 22 deletions(-) diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm index 1ae82bb..05003d9 100644 --- a/chickadee/graphics/seagull.scm +++ b/chickadee/graphics/seagull.scm @@ -1661,6 +1661,83 @@ combined-subs pred)) +(define (infer:values exps env) + (define-values (exps* exp-subs exp-pred) + (infer:list exps env)) + (values (texp (map single-type exps*) + `(values ,@exps*)) + exp-subs + exp-pred)) + +(define (infer:outputs names exps env) + (define-values (exps* exp-subs exp-pred) + (infer:list exps env)) + (values (texp (map single-type exps*) + `(outputs ,@(map list names exps*))) + exp-subs + exp-pred)) + +(define (infer:top-level bindings body env) + (define (infer-bindings bindings texps subs pred) + (match bindings + (() + (values (reverse texps) subs pred)) + ((('function _ exp) . rest) + (define-values (texp subs* pred*) + (infer-exp exp env)) + (define-values (new-pred combined-subs) + (eval-predicate* (compose-predicates pred pred*) + (compose-substitutions subs subs*))) + (infer-bindings rest + (cons texp texps) + combined-subs + new-pred)) + (((_ type-name _) . rest) + (infer-bindings rest + (cons (list (type-name->type type-name)) texps) + subs + pred)))) + (define qualifiers (map first bindings)) + (define names + (map (match-lambda + (('function name _) name) + ((_ _ name) name)) + bindings)) + (define type-names + (map (match-lambda + (((or 'in 'out) type-name _) type-name) + (_ #f)) + bindings)) + (define-values (exps exp-subs exp-pred) + (infer-bindings bindings '() '() #t)) + (define exp-types + (map (lambda (x) + (if (texp? x) + (texp-types x) + x)) + exps)) + (define env* + (fold extend-env + (apply-substitutions-to-env env exp-subs) + names + exp-types)) + (define-values (body* body-subs body-pred) + (infer-exp body env*)) + (define-values (pred combined-subs) + (eval-predicate* (compose-predicates exp-pred body-pred) + (compose-substitutions exp-subs body-subs))) + (define bindings* + (map (match-lambda* + (((and (or 'in 'out) qualifier) type-name name _) + (list qualifier type-name name)) + (('function _ name exp) + `(function ,name ,exp))) + qualifiers type-names names exps)) + (values (texp (texp-types body*) + `(top-level ,bindings* ,body*)) + combined-subs + pred)) + ;; Inference returns 3 values: ;; - a typed expression ;; - a list of substitutions @@ -1669,40 +1746,49 @@ (match exp ((? immediate?) (infer:immediate exp)) - ((? symbol?) - (infer:variable exp env)) + ((or (? symbol? name) ('var name _)) + (infer:variable name env)) (('if predicate consequent alternate) (infer:if predicate consequent alternate env)) (('let ((names exps) ...) body) (infer:let names exps body env)) (('lambda (params ...) body) (infer:lambda params body env)) - ;; (('values exps ...) - ;; (infer:values exps env)) + (('values exps ...) + (infer:values exps env)) (('primcall operator args ...) (infer:primitive-call operator args env)) (('call operator args ...) (infer:call operator args env)) - ;; (('outputs (names exps) ...) - ;; (infer:outputs names exps env)) - ;; (('top-level bindings body) - ;; (infer:top-level bindings body env)) + (('outputs (names exps) ...) + (infer:outputs names exps env)) + (('top-level bindings body) + (infer:top-level bindings body env)) (_ (error "unknown form" exp)))) +(define (make-test-env) + (extend-env + '+ + (list (let ((a (fresh-type-variable)) + (b (fresh-type-variable)) + (c (fresh-type-variable))) + (for-all-type + (list a b c) + (function-type (list a b) + (list c)) + `(or (when (and (= ,a ,int-type) + (= ,b ,int-type)) + (substitute ,c ,a)) + (when (and (= ,a ,float-type) + (= ,b ,float-type)) + (substitute ,c ,a)))))) + (empty-env))) + ;; TODO: Add some kind of context object that is threaded through the ;; inference process so that when a type error occurs we can show the ;; expression that caused it. (define (infer-types exp stage) - (call-with-unify-rollback - (lambda () - (let ((annotated (pk 'annotated (annotate-exp* exp stage)))) - (infer annotated - '() - (lambda (subs) - (resolve annotated subs))))) - (match-lambda* - ((msg . args) - (seagull-type-error msg args infer-types))))) + (infer-exp exp (make-test-env))) ;;; @@ -1929,7 +2015,7 @@ (emit:float n version port level)) (('t _ (? boolean? b)) (emit:boolean b version port level)) - (('t _ ('var var _)) + (('t _ (? symbol? var)) (list var)) (('t _ ('if predicate consequent alternate)) (emit:if predicate consequent alternate version port level)) @@ -1937,11 +2023,11 @@ (emit:values exps version port level)) (('t types ('let ((names exps) ...) body)) (emit:let types names exps body version port level)) - (('t (type) ('primcall ('t _ (? binary-operator? op)) a b)) + (('t (type) ('primcall (? binary-operator? op) a b)) (emit:binary-operator type op a b version port level)) - (('t (type) ('primcall ('t _ (? unary-operator? op)) a)) + (('t (type) ('primcall (? unary-operator? op) a)) (emit:unary-operator type op a version port level)) - (('t (type) ('primcall ('t _ op) args ...)) + (('t (type) ('primcall op args ...)) (emit:primcall type op args version port level)) (('t types ('call operator args ...)) (emit:call types operator args version port level)) @@ -1958,6 +2044,8 @@ ;; Combine all of the compiler passes on a user provided program and ;; emit GLSL code if the program is valid. +(use-modules (ice-9 pretty-print)) + (define* (compile-seagull exp #:key (stage 'vertex) (version '330) (port (current-output-port))) @@ -1967,4 +2055,5 @@ (propagated (propagate-constants expanded (empty-env))) (hoisted (hoist-functions* propagated)) (inferred (infer-types hoisted stage))) + (pretty-print inferred) (emit-glsl inferred version port)))) -- cgit v1.2.3