diff options
author | David Thompson <dthompson2@worcester.edu> | 2023-01-24 09:01:22 -0500 |
---|---|---|
committer | David Thompson <dthompson2@worcester.edu> | 2023-06-08 08:14:41 -0400 |
commit | 4df700e5e285e4aaf4c1b7f2a87ca399d7c534e0 (patch) | |
tree | d405b4849c54b9aa6a3bc694856536d0e2157185 | |
parent | 5176e002c6cac030828713d4111b7571446c074d (diff) |
Shader outputs.
-rw-r--r-- | chickadee/graphics/seagull.scm | 234 |
1 files changed, 186 insertions, 48 deletions
diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm index 515101e..0a54b88 100644 --- a/chickadee/graphics/seagull.scm +++ b/chickadee/graphics/seagull.scm @@ -62,6 +62,9 @@ (char? x) (boolean? x))) +(define (unary-operator? x) + (eq? x 'not)) + (define (arithmetic-operator? x) (memq x '(+ - * /))) @@ -75,12 +78,30 @@ (define (vector-constructor? x) (memq x '(vec2 vec3 vec4))) +(define (conversion? x) + (memq x '(int->float float->int))) + +(define (math-function? x) + (memq x '(abs sqrt min max sin cos tan clamp mix))) + (define (primitive-call? x) (or (binary-operator? x) - (vector-constructor? x))) - -(define (input-qualifier? x) - (memq x '(in uniform))) + (unary-operator? x) + (vector-constructor? x) + (conversion? x) + (math-function? x))) + +(define (top-level-qualifier? x) + (memq x '(in out uniform))) + +(define (built-in-output? name stage) + (case stage + ((vertex) + ;; GL 4+ has more built-ins, but we are supporting GL 2+ so we + ;; can't use them easily. + (memq name '(gl-position gl-point-size gl-clip-distance))) + ((fragment) + (memq name '(gl-frag-depth gl-sample-mask))))) (define (difference a b) (match a @@ -284,11 +305,20 @@ (define (expand:top-level qualifiers types names body env) (let* ((env* (compose-envs (alpha-convert names) env))) + ;; TODO: Support interpolation qualifiers. `(top-level ,(map (lambda (qualifier type name) (list qualifier type (lookup name env*))) qualifiers types names) ,(expand body env*)))) +(define (expand:outputs names exps env) + `(outputs ,@(map (lambda (name exp) + (list (if (built-in-output? name 'vertex) + name + (lookup name env)) + (expand exp env))) + names exps))) + (define &seagull-syntax-error (make-exception-type '&seagull-syntax-error &error '(form))) @@ -315,7 +345,9 @@ (expand:lambda params body env)) (('values exps ...) (expand:values exps env)) - (('top-level (((? input-qualifier? qualifiers) types names) ...) + (('outputs ((? symbol? names) exps) ...) + (expand:outputs names exps env)) + (('top-level (((? top-level-qualifier? qualifiers) types names) ...) body) (expand:top-level qualifiers types names body env)) ;; Macros: @@ -431,6 +463,11 @@ `(top-level ,inputs ,(propagate-constants body env))) +(define (propagate:outputs names exps env) + `(outputs ,@(map (lambda (name exp) + (list name (propagate-constants exp env))) + names exps))) + (define (propagate-constants exp env) (match exp ((? immediate?) exp) @@ -450,6 +487,8 @@ (propagate:primcall operator args env)) (('call operator args ...) (propagate:call operator args env)) + (('outputs (names exps) ...) + (propagate:outputs names exps env)) (('top-level inputs body) (propagate:top-level inputs body env)))) @@ -509,6 +548,8 @@ ((or ('primcall _ args ...) ('call args ...)) (check-free-variables-in-list args bound-vars top-level-vars)) + (('outputs (names exps) ...) + (check-free-variables-in-list exps bound-vars top-level-vars)) (('top-level ((_ _ names) ...) body) (define bound-vars* (append names bound-vars)) (check-free-variables body bound-vars* top-level-vars)))) @@ -570,6 +611,14 @@ (values `(top-level ,inputs ,body*) body-env)) +(define (hoist:outputs names exps) + (define-values (exps* exp-env) + (hoist:list exps)) + (values `(outputs ,@(map (lambda (name exp) + (list name exp)) + names exps*)) + exp-env)) + (define (hoist-functions exp) (match exp ((or (? immediate?) ('var _ _)) @@ -586,6 +635,8 @@ (hoist:primcall operator args)) (('call args ...) (hoist:call args)) + (('outputs (names exps) ...) + (hoist:outputs names exps)) (('top-level inputs body) (hoist:top-level inputs body)))) @@ -612,15 +663,6 @@ ;;; -;;; Recursion analysis -;;; - -;; GLSL does not allow for recursive function calls, but Seagull -;; allows the current function to call itself in tail position because -;; they can be compiled to a loop in GLSL. - - -;;; ;;; Typed expressions ;;; @@ -737,6 +779,14 @@ (match type (('overload types) types))) +;; Outputs type: +(define outputs-type '(outputs)) + +(define (outputs-type? obj) + (match obj + (('outputs) #t) + (_ #f))) + (define (type? obj) (or (primitive-type? obj) (type-variable? obj) @@ -793,9 +843,9 @@ (function-type (list mat4-type mat4-type) (list mat4-type)) (function-type (list mat4-type vec4-type) - (list mat4-type)) + (list vec4-type)) (function-type (list vec4-type mat4-type) - (list mat4-type))))) + (list vec4-type))))) (define div-type (list (overload-type @@ -855,6 +905,31 @@ (function-type (list float-type vec3-type) (list vec4-type))))) +(define abs-type + (list (overload-type + (function-type (list int-type) (list int-type)) + (function-type (list float-type) (list float-type))))) + +(define sqrt-type + (list (overload-type + (function-type (list int-type) (list int-type)) + (function-type (list float-type) (list float-type))))) + +(define min/max-type + (list (overload-type + (function-type (list int-type int-type) (list int-type)) + (function-type (list float-type float-type) (list float-type))))) + +(define trigonometry-type + (list (function-type (list float-type) (list float-type)))) + +(define clamp/mix-type + (list (overload-type + (function-type (list int-type int-type int-type) + (list int-type)) + (function-type (list float-type float-type float-type) + (list float-type))))) + (define (top-level-type-env) `((+ . ,add/sub-type) (- . ,add/sub-type) @@ -867,7 +942,20 @@ (>= . ,comparison-type) (vec2 . ,make-vec2-type) (vec3 . ,make-vec3-type) - (vec4 . ,make-vec4-type))) + (vec4 . ,make-vec4-type) + (not ,(function-type (list bool-type) (list bool-type))) + (int->float ,(function-type (list int-type) (list float-type))) + (float->int ,(function-type (list float-type) (list int-type))) + (abs . ,abs-type) + (sqrt . ,sqrt-type) + (min . ,min/max-type) + (max . ,min/max-type) + (sin . ,trigonometry-type) + (cos . ,trigonometry-type) + (tan . ,trigonometry-type) + (clamp . ,clamp/mix-type) + (mix . ,clamp/mix-type) + (gl-position ,vec4-type))) (define (occurs? a b) (cond @@ -883,7 +971,9 @@ (define (apply-substitution-to-type type from to) (pk 'apply-substitution-to-type type from to) (cond - ((primitive-type? type) type) + ((or (primitive-type? type) + (outputs-type? type)) + type) ((type-variable? type) (if (equal? type from) to type)) ((function-type? type) @@ -1007,15 +1097,14 @@ (define (annotate:if predicate consequent alternate env) (define consequent* (annotate-exp consequent env)) - (texp (fresh-type-variables-for-list - (texp-types consequent*)) + (texp (texp-types consequent*) `(if ,(annotate-exp predicate env) ,consequent* ,(annotate-exp alternate env)))) (define (annotate:let names exps body env) (define exps* (annotate:list exps env)) - (define exp-types (map single-type exps*)) + (define exp-types (map texp-types exps*)) (define env* (fold extend-env env names exp-types)) (define body* (annotate-exp body env*)) (texp (texp-types body*) @@ -1025,7 +1114,10 @@ ;; Each function parameter gets a fresh type variable. (define param-types (fresh-type-variables-for-list params)) ;; The type environment is extended with the function parameters. - (define env* (fold extend-env env params param-types)) + (define env* + (fold (lambda (param type env*) + (extend-env param (list type) env*)) + env params param-types)) ;; The body is annotated in the new environment. (define body* (annotate-exp body env*)) (texp (list (function-type param-types (texp-types body*))) @@ -1068,11 +1160,19 @@ (define env* (extend-env name (single-type exp*) env)) (define result* (cons `(function ,name ,exp*) result)) (annotate:top-level rest body env* result*)) - ((((? input-qualifier? qualifier) type-name name) . rest) - (define env* (extend-env name (type-name->type type-name) env)) + ((((? top-level-qualifier? qualifier) type-name name) . rest) + (define env* (extend-env name (list (type-name->type type-name)) env)) (define result* (cons (list qualifier type-name name) result)) (annotate:top-level rest body env* result*)))) +(define (annotate:outputs names exps env) + (texp (list outputs-type) + `(outputs + ,@(map (lambda (name exp) + (list (texp (pk 'types name (lookup name env)) name) + (annotate-exp exp env))) + names exps)))) + (define (annotate-exp exp env) (match exp ((? exact-integer?) @@ -1082,7 +1182,7 @@ ((? boolean?) (texp (list bool-type) exp)) (('var var _) - (texp (list (lookup-type var env)) exp)) + (texp (lookup-type var env) exp)) (('if predicate consequent alternate) (annotate:if predicate consequent alternate env)) (('let ((names exps) ...) body) @@ -1095,6 +1195,8 @@ (annotate:primitive-call operator args env)) (('call operator args ...) (annotate:call operator args env)) + (('outputs (names exps) ...) + (annotate:outputs names exps env)) (('top-level bindings body) (annotate:top-level bindings body env)))) @@ -1200,6 +1302,8 @@ (unify:functions a b success)) ((a (? overload-type? b)) (unify:overload b a success)) + (((? outputs-type?) (? outputs-type?)) + (success '())) ((() ()) (success '())) (((a rest-a ...) (b rest-b ...)) @@ -1208,7 +1312,6 @@ (unify:fail "type mismatch" a b)))) (define (infer:list exps subs success) - (pk 'infer:list exps) (match exps (() (success subs)) ((exp . rest) @@ -1222,7 +1325,6 @@ (compose-substitutions sub0 sub1))))))))) (define (infer:if types predicate consequent alternate subs success) - (pk 'infer:if predicate consequent alternate subs) (infer predicate subs (lambda (sub0) @@ -1243,22 +1345,11 @@ (texp-types alternate) sub4) (lambda (sub5) - (define sub6 + (success (compose-substitutions - sub4 sub5)) - (unify (apply-substitutions-to-types - types - sub6) - (apply-substitutions-to-types - (texp-types consequent) - sub6) - (lambda (sub7) - (success - (compose-substitutions - sub6 sub7))))))))))))))) + sub4 sub5))))))))))))) (define (infer:lambda type body subs success) - (pk 'infer:lambda type body) (define type* (if (for-all-type? type) (for-all-type-ref type) @@ -1273,11 +1364,9 @@ (compose-substitutions sub0 sub1))))))) (define (infer:values types exps subs success) - (pk 'infer:values exps) (infer:list exps subs success)) (define (infer:call types operator args subs success) - (pk 'infer:call types operator args subs) (infer operator subs (lambda (sub0) @@ -1296,7 +1385,6 @@ (success (compose-substitutions sub1 sub2))))))))) (define (infer:primcall types operator args subs success) - (pk 'infer:primcall types operator args) (infer:list args subs @@ -1313,7 +1401,6 @@ (compose-substitutions sub0 sub1))))))) (define (infer:let types names exps body subs success) - (pk 'infer:let types names exps subs) (infer:list exps subs @@ -1346,10 +1433,21 @@ (success (compose-substitutions sub1 sub2))))))))) +(define (infer:outputs names exps subs success) + (define output-types (map single-type names)) + (infer:list + exps + subs + (lambda (sub0) + (unify output-types + (apply-substitutions-to-types (map single-type exps) sub0) + (lambda (sub1) + (success + (compose-substitutions sub0 sub1))))))) + (define (infer exp subs success) (match exp (('t types (or (? immediate?) ('var _ _))) - (pk 'infer:basic) (success subs)) (('t types ('if predicate consequent alternate)) (infer:if types predicate consequent alternate subs success)) @@ -1363,6 +1461,8 @@ (infer:primcall types operator args subs success)) (('t types ('call operator args ...)) (infer:call types operator args subs success)) + (('t _ ('outputs (names exps) ...)) + (infer:outputs names exps subs success)) (('t types ('top-level bindings body)) (infer:top-level types bindings body subs success)) (_ (error "unknown form" exp)))) @@ -1401,7 +1501,6 @@ (infer annotated '() (lambda (subs) - (pk 'result-subs subs) (resolve annotated subs))))) (lambda args (apply error args)))) @@ -1455,6 +1554,18 @@ (type-name type) temp a-temp op* b-temp) (list temp)) +(define (emit:unary-operator type op a port level) + (define op* + (case op + ((not) '!) + (else op))) + (define a-temp (single-temp (emit-glsl a port level))) + (define temp (unique-identifier)) + (indent level port) + (format port "~a ~a = ~a(~a);\n" + (type-name type) temp op* a-temp) + (list temp)) + (define (emit:declaration type lhs rhs port level) (indent level port) (if rhs @@ -1538,6 +1649,11 @@ let-temps) (define (emit:primcall type operator args port level) + (define operator* + (case operator + ((float->int) 'float) + ((int->float) 'int) + (else operator))) (define arg-temps (map (lambda (arg) (single-temp (emit-glsl arg port level))) @@ -1547,7 +1663,7 @@ (format port "~a ~a = ~a(~a);\n" (type-name type) output-temp - operator + operator* (string-join (map symbol->string arg-temps) ", ")) (list output-temp)) @@ -1568,7 +1684,7 @@ (define (emit:top-level bindings body port level) (for-each (match-lambda - (((? input-qualifier? qualifier) type-name name) + (((? top-level-qualifier? qualifier) type-name name) (format port "~a ~a ~a;\n" qualifier type-name name)) (('function name ('t (type) ('lambda params body))) (emit:function name type params body port level))) @@ -1577,6 +1693,24 @@ (emit-glsl body port (+ level 1)) (display "}\n" port)) +(define (emit:outputs names exps port level) + (define (output-name name) + (case name + ((gl-position) 'gl_Position) + ((gl-point-size) 'gl_PointSize) + ((gl-clip-distance) 'gl_ClipDistance) + ((gl-frag-depth) 'gl_FragDepth) + ((gl-sample-mask) 'gl_SampleMask) + (else name))) + (for-each (lambda (name exp) + (match (emit-glsl exp port level) + ((temp) + (indent level port) + (format port "~a = ~a;\n" + (output-name (texp-exp name)) + temp)))) + names exps)) + (define* (emit-glsl exp port #:optional (level 0)) (match exp (('t _ (? exact-integer? n)) @@ -1595,10 +1729,14 @@ (emit:let types names exps body port level)) (('t (type) ('primcall ('t _ (? binary-operator? op)) a b)) (emit:binary-operator type op a b port level)) + (('t (type) ('primcall ('t _ (? unary-operator? op)) a)) + (emit:unary-operator type op a port level)) (('t (type) ('primcall ('t _ op) args ...)) (emit:primcall type op args port level)) (('t types ('call operator args ...)) (emit:call types operator args port level)) + (('t _ ('outputs (names exps) ...)) + (emit:outputs names exps port level)) (('t _ ('top-level (bindings ...) body)) (emit:top-level bindings body port level)))) |