summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Thompson <dthompson2@worcester.edu>2023-01-24 09:01:22 -0500
committerDavid Thompson <dthompson2@worcester.edu>2023-06-08 08:14:41 -0400
commit4df700e5e285e4aaf4c1b7f2a87ca399d7c534e0 (patch)
treed405b4849c54b9aa6a3bc694856536d0e2157185
parent5176e002c6cac030828713d4111b7571446c074d (diff)
Shader outputs.
-rw-r--r--chickadee/graphics/seagull.scm234
1 files changed, 186 insertions, 48 deletions
diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm
index 515101e..0a54b88 100644
--- a/chickadee/graphics/seagull.scm
+++ b/chickadee/graphics/seagull.scm
@@ -62,6 +62,9 @@
(char? x)
(boolean? x)))
+(define (unary-operator? x)
+ (eq? x 'not))
+
(define (arithmetic-operator? x)
(memq x '(+ - * /)))
@@ -75,12 +78,30 @@
(define (vector-constructor? x)
(memq x '(vec2 vec3 vec4)))
+(define (conversion? x)
+ (memq x '(int->float float->int)))
+
+(define (math-function? x)
+ (memq x '(abs sqrt min max sin cos tan clamp mix)))
+
(define (primitive-call? x)
(or (binary-operator? x)
- (vector-constructor? x)))
-
-(define (input-qualifier? x)
- (memq x '(in uniform)))
+ (unary-operator? x)
+ (vector-constructor? x)
+ (conversion? x)
+ (math-function? x)))
+
+(define (top-level-qualifier? x)
+ (memq x '(in out uniform)))
+
+(define (built-in-output? name stage)
+ (case stage
+ ((vertex)
+ ;; GL 4+ has more built-ins, but we are supporting GL 2+ so we
+ ;; can't use them easily.
+ (memq name '(gl-position gl-point-size gl-clip-distance)))
+ ((fragment)
+ (memq name '(gl-frag-depth gl-sample-mask)))))
(define (difference a b)
(match a
@@ -284,11 +305,20 @@
(define (expand:top-level qualifiers types names body env)
(let* ((env* (compose-envs (alpha-convert names) env)))
+ ;; TODO: Support interpolation qualifiers.
`(top-level ,(map (lambda (qualifier type name)
(list qualifier type (lookup name env*)))
qualifiers types names)
,(expand body env*))))
+(define (expand:outputs names exps env)
+ `(outputs ,@(map (lambda (name exp)
+ (list (if (built-in-output? name 'vertex)
+ name
+ (lookup name env))
+ (expand exp env)))
+ names exps)))
+
(define &seagull-syntax-error
(make-exception-type '&seagull-syntax-error &error '(form)))
@@ -315,7 +345,9 @@
(expand:lambda params body env))
(('values exps ...)
(expand:values exps env))
- (('top-level (((? input-qualifier? qualifiers) types names) ...)
+ (('outputs ((? symbol? names) exps) ...)
+ (expand:outputs names exps env))
+ (('top-level (((? top-level-qualifier? qualifiers) types names) ...)
body)
(expand:top-level qualifiers types names body env))
;; Macros:
@@ -431,6 +463,11 @@
`(top-level ,inputs
,(propagate-constants body env)))
+(define (propagate:outputs names exps env)
+ `(outputs ,@(map (lambda (name exp)
+ (list name (propagate-constants exp env)))
+ names exps)))
+
(define (propagate-constants exp env)
(match exp
((? immediate?) exp)
@@ -450,6 +487,8 @@
(propagate:primcall operator args env))
(('call operator args ...)
(propagate:call operator args env))
+ (('outputs (names exps) ...)
+ (propagate:outputs names exps env))
(('top-level inputs body)
(propagate:top-level inputs body env))))
@@ -509,6 +548,8 @@
((or ('primcall _ args ...)
('call args ...))
(check-free-variables-in-list args bound-vars top-level-vars))
+ (('outputs (names exps) ...)
+ (check-free-variables-in-list exps bound-vars top-level-vars))
(('top-level ((_ _ names) ...) body)
(define bound-vars* (append names bound-vars))
(check-free-variables body bound-vars* top-level-vars))))
@@ -570,6 +611,14 @@
(values `(top-level ,inputs ,body*)
body-env))
+(define (hoist:outputs names exps)
+ (define-values (exps* exp-env)
+ (hoist:list exps))
+ (values `(outputs ,@(map (lambda (name exp)
+ (list name exp))
+ names exps*))
+ exp-env))
+
(define (hoist-functions exp)
(match exp
((or (? immediate?) ('var _ _))
@@ -586,6 +635,8 @@
(hoist:primcall operator args))
(('call args ...)
(hoist:call args))
+ (('outputs (names exps) ...)
+ (hoist:outputs names exps))
(('top-level inputs body)
(hoist:top-level inputs body))))
@@ -612,15 +663,6 @@
;;;
-;;; Recursion analysis
-;;;
-
-;; GLSL does not allow for recursive function calls, but Seagull
-;; allows the current function to call itself in tail position because
-;; they can be compiled to a loop in GLSL.
-
-
-;;;
;;; Typed expressions
;;;
@@ -737,6 +779,14 @@
(match type
(('overload types) types)))
+;; Outputs type:
+(define outputs-type '(outputs))
+
+(define (outputs-type? obj)
+ (match obj
+ (('outputs) #t)
+ (_ #f)))
+
(define (type? obj)
(or (primitive-type? obj)
(type-variable? obj)
@@ -793,9 +843,9 @@
(function-type (list mat4-type mat4-type)
(list mat4-type))
(function-type (list mat4-type vec4-type)
- (list mat4-type))
+ (list vec4-type))
(function-type (list vec4-type mat4-type)
- (list mat4-type)))))
+ (list vec4-type)))))
(define div-type
(list (overload-type
@@ -855,6 +905,31 @@
(function-type (list float-type vec3-type)
(list vec4-type)))))
+(define abs-type
+ (list (overload-type
+ (function-type (list int-type) (list int-type))
+ (function-type (list float-type) (list float-type)))))
+
+(define sqrt-type
+ (list (overload-type
+ (function-type (list int-type) (list int-type))
+ (function-type (list float-type) (list float-type)))))
+
+(define min/max-type
+ (list (overload-type
+ (function-type (list int-type int-type) (list int-type))
+ (function-type (list float-type float-type) (list float-type)))))
+
+(define trigonometry-type
+ (list (function-type (list float-type) (list float-type))))
+
+(define clamp/mix-type
+ (list (overload-type
+ (function-type (list int-type int-type int-type)
+ (list int-type))
+ (function-type (list float-type float-type float-type)
+ (list float-type)))))
+
(define (top-level-type-env)
`((+ . ,add/sub-type)
(- . ,add/sub-type)
@@ -867,7 +942,20 @@
(>= . ,comparison-type)
(vec2 . ,make-vec2-type)
(vec3 . ,make-vec3-type)
- (vec4 . ,make-vec4-type)))
+ (vec4 . ,make-vec4-type)
+ (not ,(function-type (list bool-type) (list bool-type)))
+ (int->float ,(function-type (list int-type) (list float-type)))
+ (float->int ,(function-type (list float-type) (list int-type)))
+ (abs . ,abs-type)
+ (sqrt . ,sqrt-type)
+ (min . ,min/max-type)
+ (max . ,min/max-type)
+ (sin . ,trigonometry-type)
+ (cos . ,trigonometry-type)
+ (tan . ,trigonometry-type)
+ (clamp . ,clamp/mix-type)
+ (mix . ,clamp/mix-type)
+ (gl-position ,vec4-type)))
(define (occurs? a b)
(cond
@@ -883,7 +971,9 @@
(define (apply-substitution-to-type type from to)
(pk 'apply-substitution-to-type type from to)
(cond
- ((primitive-type? type) type)
+ ((or (primitive-type? type)
+ (outputs-type? type))
+ type)
((type-variable? type)
(if (equal? type from) to type))
((function-type? type)
@@ -1007,15 +1097,14 @@
(define (annotate:if predicate consequent alternate env)
(define consequent* (annotate-exp consequent env))
- (texp (fresh-type-variables-for-list
- (texp-types consequent*))
+ (texp (texp-types consequent*)
`(if ,(annotate-exp predicate env)
,consequent*
,(annotate-exp alternate env))))
(define (annotate:let names exps body env)
(define exps* (annotate:list exps env))
- (define exp-types (map single-type exps*))
+ (define exp-types (map texp-types exps*))
(define env* (fold extend-env env names exp-types))
(define body* (annotate-exp body env*))
(texp (texp-types body*)
@@ -1025,7 +1114,10 @@
;; Each function parameter gets a fresh type variable.
(define param-types (fresh-type-variables-for-list params))
;; The type environment is extended with the function parameters.
- (define env* (fold extend-env env params param-types))
+ (define env*
+ (fold (lambda (param type env*)
+ (extend-env param (list type) env*))
+ env params param-types))
;; The body is annotated in the new environment.
(define body* (annotate-exp body env*))
(texp (list (function-type param-types (texp-types body*)))
@@ -1068,11 +1160,19 @@
(define env* (extend-env name (single-type exp*) env))
(define result* (cons `(function ,name ,exp*) result))
(annotate:top-level rest body env* result*))
- ((((? input-qualifier? qualifier) type-name name) . rest)
- (define env* (extend-env name (type-name->type type-name) env))
+ ((((? top-level-qualifier? qualifier) type-name name) . rest)
+ (define env* (extend-env name (list (type-name->type type-name)) env))
(define result* (cons (list qualifier type-name name) result))
(annotate:top-level rest body env* result*))))
+(define (annotate:outputs names exps env)
+ (texp (list outputs-type)
+ `(outputs
+ ,@(map (lambda (name exp)
+ (list (texp (pk 'types name (lookup name env)) name)
+ (annotate-exp exp env)))
+ names exps))))
+
(define (annotate-exp exp env)
(match exp
((? exact-integer?)
@@ -1082,7 +1182,7 @@
((? boolean?)
(texp (list bool-type) exp))
(('var var _)
- (texp (list (lookup-type var env)) exp))
+ (texp (lookup-type var env) exp))
(('if predicate consequent alternate)
(annotate:if predicate consequent alternate env))
(('let ((names exps) ...) body)
@@ -1095,6 +1195,8 @@
(annotate:primitive-call operator args env))
(('call operator args ...)
(annotate:call operator args env))
+ (('outputs (names exps) ...)
+ (annotate:outputs names exps env))
(('top-level bindings body)
(annotate:top-level bindings body env))))
@@ -1200,6 +1302,8 @@
(unify:functions a b success))
((a (? overload-type? b))
(unify:overload b a success))
+ (((? outputs-type?) (? outputs-type?))
+ (success '()))
((() ())
(success '()))
(((a rest-a ...) (b rest-b ...))
@@ -1208,7 +1312,6 @@
(unify:fail "type mismatch" a b))))
(define (infer:list exps subs success)
- (pk 'infer:list exps)
(match exps
(() (success subs))
((exp . rest)
@@ -1222,7 +1325,6 @@
(compose-substitutions sub0 sub1)))))))))
(define (infer:if types predicate consequent alternate subs success)
- (pk 'infer:if predicate consequent alternate subs)
(infer predicate
subs
(lambda (sub0)
@@ -1243,22 +1345,11 @@
(texp-types alternate)
sub4)
(lambda (sub5)
- (define sub6
+ (success
(compose-substitutions
- sub4 sub5))
- (unify (apply-substitutions-to-types
- types
- sub6)
- (apply-substitutions-to-types
- (texp-types consequent)
- sub6)
- (lambda (sub7)
- (success
- (compose-substitutions
- sub6 sub7)))))))))))))))
+ sub4 sub5)))))))))))))
(define (infer:lambda type body subs success)
- (pk 'infer:lambda type body)
(define type*
(if (for-all-type? type)
(for-all-type-ref type)
@@ -1273,11 +1364,9 @@
(compose-substitutions sub0 sub1)))))))
(define (infer:values types exps subs success)
- (pk 'infer:values exps)
(infer:list exps subs success))
(define (infer:call types operator args subs success)
- (pk 'infer:call types operator args subs)
(infer operator
subs
(lambda (sub0)
@@ -1296,7 +1385,6 @@
(success (compose-substitutions sub1 sub2)))))))))
(define (infer:primcall types operator args subs success)
- (pk 'infer:primcall types operator args)
(infer:list
args
subs
@@ -1313,7 +1401,6 @@
(compose-substitutions sub0 sub1)))))))
(define (infer:let types names exps body subs success)
- (pk 'infer:let types names exps subs)
(infer:list
exps
subs
@@ -1346,10 +1433,21 @@
(success
(compose-substitutions sub1 sub2)))))))))
+(define (infer:outputs names exps subs success)
+ (define output-types (map single-type names))
+ (infer:list
+ exps
+ subs
+ (lambda (sub0)
+ (unify output-types
+ (apply-substitutions-to-types (map single-type exps) sub0)
+ (lambda (sub1)
+ (success
+ (compose-substitutions sub0 sub1)))))))
+
(define (infer exp subs success)
(match exp
(('t types (or (? immediate?) ('var _ _)))
- (pk 'infer:basic)
(success subs))
(('t types ('if predicate consequent alternate))
(infer:if types predicate consequent alternate subs success))
@@ -1363,6 +1461,8 @@
(infer:primcall types operator args subs success))
(('t types ('call operator args ...))
(infer:call types operator args subs success))
+ (('t _ ('outputs (names exps) ...))
+ (infer:outputs names exps subs success))
(('t types ('top-level bindings body))
(infer:top-level types bindings body subs success))
(_ (error "unknown form" exp))))
@@ -1401,7 +1501,6 @@
(infer annotated
'()
(lambda (subs)
- (pk 'result-subs subs)
(resolve annotated subs)))))
(lambda args
(apply error args))))
@@ -1455,6 +1554,18 @@
(type-name type) temp a-temp op* b-temp)
(list temp))
+(define (emit:unary-operator type op a port level)
+ (define op*
+ (case op
+ ((not) '!)
+ (else op)))
+ (define a-temp (single-temp (emit-glsl a port level)))
+ (define temp (unique-identifier))
+ (indent level port)
+ (format port "~a ~a = ~a(~a);\n"
+ (type-name type) temp op* a-temp)
+ (list temp))
+
(define (emit:declaration type lhs rhs port level)
(indent level port)
(if rhs
@@ -1538,6 +1649,11 @@
let-temps)
(define (emit:primcall type operator args port level)
+ (define operator*
+ (case operator
+ ((float->int) 'float)
+ ((int->float) 'int)
+ (else operator)))
(define arg-temps
(map (lambda (arg)
(single-temp (emit-glsl arg port level)))
@@ -1547,7 +1663,7 @@
(format port "~a ~a = ~a(~a);\n"
(type-name type)
output-temp
- operator
+ operator*
(string-join (map symbol->string arg-temps) ", "))
(list output-temp))
@@ -1568,7 +1684,7 @@
(define (emit:top-level bindings body port level)
(for-each (match-lambda
- (((? input-qualifier? qualifier) type-name name)
+ (((? top-level-qualifier? qualifier) type-name name)
(format port "~a ~a ~a;\n" qualifier type-name name))
(('function name ('t (type) ('lambda params body)))
(emit:function name type params body port level)))
@@ -1577,6 +1693,24 @@
(emit-glsl body port (+ level 1))
(display "}\n" port))
+(define (emit:outputs names exps port level)
+ (define (output-name name)
+ (case name
+ ((gl-position) 'gl_Position)
+ ((gl-point-size) 'gl_PointSize)
+ ((gl-clip-distance) 'gl_ClipDistance)
+ ((gl-frag-depth) 'gl_FragDepth)
+ ((gl-sample-mask) 'gl_SampleMask)
+ (else name)))
+ (for-each (lambda (name exp)
+ (match (emit-glsl exp port level)
+ ((temp)
+ (indent level port)
+ (format port "~a = ~a;\n"
+ (output-name (texp-exp name))
+ temp))))
+ names exps))
+
(define* (emit-glsl exp port #:optional (level 0))
(match exp
(('t _ (? exact-integer? n))
@@ -1595,10 +1729,14 @@
(emit:let types names exps body port level))
(('t (type) ('primcall ('t _ (? binary-operator? op)) a b))
(emit:binary-operator type op a b port level))
+ (('t (type) ('primcall ('t _ (? unary-operator? op)) a))
+ (emit:unary-operator type op a port level))
(('t (type) ('primcall ('t _ op) args ...))
(emit:primcall type op args port level))
(('t types ('call operator args ...))
(emit:call types operator args port level))
+ (('t _ ('outputs (names exps) ...))
+ (emit:outputs names exps port level))
(('t _ ('top-level (bindings ...) body))
(emit:top-level bindings body port level))))