diff options
-rw-r--r-- | chickadee/graphics/seagull.scm | 247 |
1 files changed, 162 insertions, 85 deletions
diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm index 4e271ff..44a9bd9 100644 --- a/chickadee/graphics/seagull.scm +++ b/chickadee/graphics/seagull.scm @@ -39,12 +39,15 @@ ;; ;; TODO: ;; - Loops +;; - discard ;; - (define ...) form +;; - struct field aliases (rgba for vec4, for example) ;; - Scheme shader type -> GLSL struct translation ;; - Dead code elimination (error when a uniform is eliminated) ;; - User defined structs ;; - Multiple GLSL versions ;; - Better error messages (especially around type predicate failure) +;; - Refactor to add define-primitive syntax ;; ;;; Code: (define-module (chickadee graphics seagull) @@ -114,7 +117,7 @@ (memq x '(int->float float->int))) (define (math-function? x) - (memq x '(abs sqrt min max mod floor ceil sin cos tan clamp mix))) + (memq x '(abs sqrt min max mod floor ceil sin cos tan clamp mix length))) (define (vertex-primitive-call? x) #f) @@ -1235,8 +1238,8 @@ ,(apply-substitution-to-type b from to))) (('struct-field struct field var) `(struct-field ,(apply-substitution-to-type struct from to) - ,field - ,(apply-substitution-to-type var from to))) + ,field + ,(apply-substitution-to-type var from to))) (('array-element array var) `(array-element ,(apply-substitution-to-type array from to) ,(apply-substitution-to-type var from to))))) @@ -1405,7 +1408,10 @@ ;; struct type. (('struct-field struct field field-var) (if (struct-type? struct) - (values #t (list (cons field-var (struct-type-ref struct field)))) + (let ((field-type (struct-type-ref struct field))) + (if field-type + (values #t (list (cons field-var field-type))) + (values #f '()))) (values pred '()))) ;; Substitute the element var when array has been resolved to an ;; array type. @@ -1774,11 +1780,15 @@ (infer-exp exp env)) (define exp-type (single-type exp*)) (define tvar (fresh-variable-type)) - (values (texp (list tvar) - `(struct-ref ,exp* ,field)) - exp-subs - (compose-predicates exp-pred - (predicate:struct-field exp-type field tvar)))) + (define-values (pred combined-subs) + (eval-predicate* (compose-predicates (predicate:struct-field exp-type field tvar) + exp-pred) + exp-subs)) + (values (texp (list (apply-substitutions-to-type tvar combined-subs)) + `(struct-ref ,(apply-substitutions-to-texp exp* combined-subs) + ,field)) + combined-subs + pred)) (define (infer:array-ref array-exp index-exp env) (define-values (array-exp* array-exp-subs array-exp-pred) @@ -1823,7 +1833,11 @@ (eval-predicate* (compose-predicates exp-pred body-pred) (compose-substitutions exp-subs body-subs))) (values (texp (texp-types body*) - `(let ,(map list names exps*) ,body*)) + `(let ,(map (lambda (name exp) + (list name (apply-substitutions-to-texp + exp combined-subs))) + names exps*) + ,(apply-substitutions-to-texp body* combined-subs))) combined-subs pred)) @@ -1848,7 +1862,11 @@ (define-values (pred combined-subs) (eval-predicate* exp-pred (compose-substitutions exp-subs unify-subs))) (values (texp (list type:outputs) - `(outputs ,@(map list names exps*))) + `(outputs + ,@(map (lambda (name exp) + (list name (apply-substitutions-to-texp + exp combined-subs))) + names exps*))) combined-subs pred)) @@ -2025,8 +2043,8 @@ (type:vec4 type:float type:vec4) (type:float type:vec4 type:vec4) (type:mat3 type:mat3 type:mat3) - (type:mat3 type:vec3 type:mat3) - (type:vec3 type:mat3 type:mat3) + (type:mat3 type:vec3 type:vec3) + (type:vec3 type:mat3 type:vec3) (type:mat4 type:mat4 type:mat4) (type:mat4 type:vec4 type:vec4) (type:vec4 type:mat4 type:vec4))) @@ -2087,6 +2105,17 @@ (define type:make-vec4 (list (function-type (list type:float type:float type:float type:float) (list type:vec4)))) + (define type:length + (let ((a (fresh-variable-type))) + (list (type-scheme + (list a) + (qualified-type + (function-type (list a) (list type:float)) + (predicate:or + (predicate:= a type:float) + (predicate:= a type:vec2) + (predicate:= a type:vec3) + (predicate:= a type:vec4))))))) (define type:abs (let ((a (fresh-variable-type))) (list (type-scheme @@ -2157,6 +2186,7 @@ (vec2 . ,type:make-vec2) (vec3 . ,type:make-vec3) (vec4 . ,type:make-vec4) + (length . ,type:length) (abs . ,type:abs) (sqrt . ,type:sqrt) (min . ,type:min/max) @@ -2236,14 +2266,25 @@ (define (vars->subs exp env) (match exp (('t ((? variable-type? tvar)) (? symbol? name)) - (list (cons tvar (lookup name env)))) + (let ((type (lookup* name env))) + (if type + (list (cons tvar type)) + '()))) ((head . rest) (delete-duplicates (append (vars->subs head env) (vars->subs rest env)))) (_ '()))) -(define (resolve-overloads program) +(define (untype x) + (match x + (('t (_ ...) exp) + (untype exp)) + ((exp . rest) + (cons (untype exp) (untype rest))) + (_ x))) + +(define (resolve-overloads program stage) ;; Find all of the struct types used in the program. They will be ;; used to generate overloaded functions that take one or more ;; structs as arguments. @@ -2251,7 +2292,8 @@ (match program (('t types ('top-level bindings body)) (define bindings* - (let loop ((bindings bindings)) + (let loop ((bindings bindings) + (globals (empty-env))) (match bindings (() '()) ((('function name ('t ((? type-scheme? type)) func)) . rest) @@ -2267,8 +2309,16 @@ (('lambda (params ...) _) params))) (define env - (fold extend-env (empty-env) params - (function-type-parameters type*))) + (compose-envs (fold extend-env (empty-env) params + (map list (function-type-parameters type*))) + globals)) + (pk 're-infer-func env) + (match func + (('lambda _ body) + (pretty-print + (infer-exp (untype body) + (compose-envs env + (top-level-type-env stage)))))) (define subs* (compose-substitutions subs (vars->subs func env))) @@ -2283,9 +2333,17 @@ ;; rest) ;; (find-signatures name body)) )) - (loop rest))) - ((binding . rest) - (cons binding (loop rest)))))) + (loop rest globals))) + ((('function name texp) . rest) + (cons `(function ,name ,texp) + (loop rest globals))) + (((qualifier type name) . rest) + (pk 'global qualifier type name) + (cons (list qualifier type name) + (loop rest + (extend-env name + (list (type-descriptor->type type)) + globals))))))) `(t ,types (top-level ,bindings* ,body))) (_ (error "expected top-level form" program)))) @@ -2330,43 +2388,43 @@ (when (> n 0) (display (make-string (* n 2) #\space) port))) -(define (emit:int n version port level) +(define (emit:int n stage version port level) (define temp (unique-identifier)) (indent level port) (format port "int ~a = ~a;\n" temp n) (list temp)) -(define (emit:float n version port level) +(define (emit:float n stage version port level) (define temp (unique-identifier)) (indent level port) (format port "float ~a = ~a;\n" temp n) (list temp)) -(define (emit:boolean b version port level) +(define (emit:boolean b stage version port level) (define temp (unique-identifier)) (indent level port) (format port "bool ~a = ~a;\n" temp (if b "true" "false")) (list temp)) -(define (emit:binary-operator type op a b version port level) +(define (emit:binary-operator type op a b stage version port level) (define op* (case op ((=) '==) (else op))) - (define a-temp (single-temp (emit-glsl a version port level))) - (define b-temp (single-temp (emit-glsl b version port level))) + (define a-temp (single-temp (emit-glsl a stage version port level))) + (define b-temp (single-temp (emit-glsl b stage version port level))) (define temp (unique-identifier)) (indent level port) (format port "~a ~a = ~a ~a ~a;\n" (type->glsl type) temp a-temp op* b-temp) (list temp)) -(define (emit:unary-operator type op a version port level) +(define (emit:unary-operator type op a stage version port level) (define op* (case op ((not) '!) (else op))) - (define a-temp (single-temp (emit-glsl a version port level))) + (define a-temp (single-temp (emit-glsl a stage version port level))) (define temp (unique-identifier)) (indent level port) (format port "~a ~a = ~a(~a);\n" @@ -2387,10 +2445,11 @@ types lhs-list rhs-list*)) (define (emit:mov a b port level) - (indent level port) - (format port "~a = ~a;\n" a b)) + (when a + (indent level port) + (format port "~a = ~a;\n" a b))) -(define (emit:function name type params body version port level) +(define (emit:function name type params body stage version port level) (define param-types (function-type-parameters type)) (define return-types (function-type-returns type)) (define outputs (unique-identifiers-for-list return-types)) @@ -2412,29 +2471,32 @@ qualifier (type->glsl type) name) (loop rest #f)))) (display ") {\n" port) - (define body-temps (emit-glsl body version port (+ level 1))) + (define body-temps (emit-glsl body stage version port (+ level 1))) (for-each (lambda (output temp) (emit:mov output temp port (+ level 1))) outputs body-temps) (indent level port) (display "}\n" port)) -(define (emit:if predicate consequent alternate version port level) - (define if-temps (unique-identifiers-for-list (texp-types consequent))) +(define (emit:if predicate consequent alternate stage version port level) + (define if-temps + (if (equal? (texp-types consequent) (list type:outputs)) + '(#f) + (unique-identifiers-for-list (texp-types consequent)))) (emit:declarations (texp-types consequent) if-temps #f port level) (define predicate-temp - (single-temp (emit-glsl predicate version port level))) + (single-temp (emit-glsl predicate stage version port level))) (indent level port) (format port "if(~a) {\n" predicate-temp) (define consequent-temps - (emit-glsl consequent version port (+ level 1))) + (emit-glsl consequent stage version port (+ level 1))) (for-each (lambda (lhs rhs) (emit:mov lhs rhs port (+ level 1))) if-temps consequent-temps) (indent level port) (display "} else {\n" port) (define alternate-temps - (emit-glsl alternate version port (+ level 1))) + (emit-glsl alternate stage version port (+ level 1))) (for-each (lambda (lhs rhs) (emit:mov lhs rhs port (+ level 1))) if-temps alternate-temps) @@ -2442,19 +2504,19 @@ (display "}\n" port) if-temps) -(define (emit:values exps version port level) +(define (emit:values exps stage version port level) (append-map (lambda (exp) - (emit-glsl exp version port level)) + (emit-glsl exp stage version port level)) exps)) -(define (emit:let types names exps body version port level) +(define (emit:let types names exps body stage version port level) (define binding-temps (map (lambda (exp) - (single-temp (emit-glsl exp version port level))) + (single-temp (emit-glsl exp stage version port level))) exps)) (define binding-types (map single-type exps)) (emit:declarations binding-types names binding-temps port level) - (define body-temps (emit-glsl body version port level)) + (define body-temps (emit-glsl body stage version port level)) (define let-temps (unique-identifiers-for-list types)) (emit:declarations (texp-types body) let-temps body-temps port level) let-temps) @@ -2464,12 +2526,12 @@ (int->float . float) (texture-2d . texture2D))) -(define (emit:primcall type operator args version port level) +(define (emit:primcall type operator args stage version port level) (define operator* (or (assq-ref %primcall-map operator) operator)) (define arg-temps (map (lambda (arg) - (single-temp (emit-glsl arg version port level))) + (single-temp (emit-glsl arg stage version port level))) args)) (define output-temp (unique-identifier)) (indent level port) @@ -2480,11 +2542,11 @@ (string-join (map symbol->string arg-temps) ", ")) (list output-temp)) -(define (emit:call types operator args version port level) - (define operator-name (single-temp (emit-glsl operator version port))) +(define (emit:call types operator args stage version port level) + (define operator-name (single-temp (emit-glsl operator stage version port))) (define arg-temps (map (lambda (arg) - (single-temp (emit-glsl arg version port level))) + (single-temp (emit-glsl arg stage version port level))) args)) (define output-temps (unique-identifiers-for-list types)) (emit:declarations types output-temps #f port level) @@ -2495,8 +2557,8 @@ ", ")) output-temps) -(define (emit:struct-ref type exp field version port level) - (define input-temp (single-temp (emit-glsl exp version port level))) +(define (emit:struct-ref type exp field stage version port level) + (define input-temp (single-temp (emit-glsl exp stage version port level))) (define output-temp (unique-identifier)) (indent level port) (format port "~a ~a = ~a.~a;\n" @@ -2506,9 +2568,9 @@ field) (list output-temp)) -(define (emit:array-ref type array-exp index-exp version port level) - (define array-temp (single-temp (emit-glsl array-exp version port level))) - (define index-temp (single-temp (emit-glsl index-exp version port level))) +(define (emit:array-ref type array-exp index-exp stage version port level) + (define array-temp (single-temp (emit-glsl array-exp stage version port level))) + (define index-temp (single-temp (emit-glsl index-exp stage version port level))) (define output-temp (unique-identifier)) (indent level port) (format port "~a ~a = ~a[~a];\n" @@ -2518,7 +2580,7 @@ index-temp) (list output-temp)) -(define (emit:top-level bindings body version port level) +(define (emit:top-level bindings body stage version port level) (for-each (match-lambda (((? top-level-qualifier? qualifier) type-desc name) (format port "~a ~a ~a;\n" @@ -2526,10 +2588,10 @@ (type-descriptor->glsl type-desc) name)) (('function name ('t (type) ('lambda params body))) - (emit:function name type params body version port level))) + (emit:function name type params body stage version port level))) bindings) (display "void main() {\n" port) - (emit-glsl body version port (+ level 1)) + (emit-glsl body stage version port (+ level 1)) (display "}\n" port)) (define %built-in-output-map @@ -2538,51 +2600,55 @@ (vertex:clip-distance . gl_ClipDistance) (fragment:depth . gl_FragDepth))) -(define (emit:outputs names exps version port level) +(define (emit:outputs names exps stage version port level) (define (output-name name) (or (assq-ref %built-in-output-map name) name)) - (for-each (lambda (name exp) - (match (emit-glsl exp version port level) - ((temp) - (indent level port) - (format port "~a = ~a;\n" - (output-name name) - temp)))) - names exps) + (if (and (eq? stage 'fragment) (null? names)) + (begin + (indent level port) + (format port "discard;\n")) + (for-each (lambda (name exp) + (match (emit-glsl exp stage version port level) + ((temp) + (indent level port) + (format port "~a = ~a;\n" + (output-name name) + temp)))) + names exps)) '(#f)) -(define* (emit-glsl exp version port #:optional (level 0)) +(define* (emit-glsl exp stage version port #:optional (level 0)) (match exp (('t _ (? exact-integer? n)) - (emit:int n version port level)) + (emit:int n stage version port level)) (('t _ (? float? n)) - (emit:float n version port level)) + (emit:float n stage version port level)) (('t _ (? boolean? b)) - (emit:boolean b version port level)) + (emit:boolean b stage version port level)) (('t _ (? symbol? var)) (list var)) (('t _ ('if predicate consequent alternate)) - (emit:if predicate consequent alternate version port level)) + (emit:if predicate consequent alternate stage version port level)) (('t _ ('values exps ...)) - (emit:values exps version port level)) + (emit:values exps stage version port level)) (('t types ('let ((names exps) ...) body)) - (emit:let types names exps body version port level)) + (emit:let types names exps body stage version port level)) (('t (type) ('primcall (? binary-operator? op) a b)) - (emit:binary-operator type op a b version port level)) + (emit:binary-operator type op a b stage version port level)) (('t (type) ('primcall (? unary-operator? op) a)) - (emit:unary-operator type op a version port level)) + (emit:unary-operator type op a stage version port level)) (('t (type) ('primcall op args ...)) - (emit:primcall type op args version port level)) + (emit:primcall type op args stage version port level)) (('t types ('call operator args ...)) - (emit:call types operator args version port level)) + (emit:call types operator args stage version port level)) (('t (type) ('struct-ref exp field)) - (emit:struct-ref type exp field version port level)) + (emit:struct-ref type exp field stage version port level)) (('t (type) ('array-ref array-exp index-exp)) - (emit:array-ref type array-exp index-exp version port level)) + (emit:array-ref type array-exp index-exp stage version port level)) (('t _ ('outputs (names exps) ...)) - (emit:outputs names exps version port level)) + (emit:outputs names exps stage version port level)) (('t _ ('top-level (bindings ...) body)) - (emit:top-level bindings body version port level)))) + (emit:top-level bindings body stage version port level)))) ;;; @@ -2657,7 +2723,7 @@ (let* ((propagated (propagate-constants expanded (empty-env))) (hoisted (hoist-functions* propagated)) (inferred (infer-types hoisted stage)) - (resolved (resolve-overloads inferred))) + (resolved (resolve-overloads inferred stage))) (values resolved global-map (unique-identifier-counter)))))) (define (specs->globals specs) @@ -2666,6 +2732,7 @@ (make-seagull-global qualifier type-desc name))) specs)) +(use-modules (ice-9 pretty-print)) ;; Using syntax-case allows us to compile shaders to their fully typed ;; intermediate form at compile time, leaving only GLSL emission for ;; runtime. @@ -2685,6 +2752,7 @@ #:inputs inputs #:outputs outputs #:uniforms uniforms)) + (pretty-print compiled) (with-syntax ((inputs (datum->syntax x inputs)) (outputs (datum->syntax x outputs)) (uniforms (datum->syntax x uniforms)) @@ -2784,8 +2852,17 @@ (fragment-uniform-map (map-globals (seagull-module-uniforms fragment) fragment-global-map)) ;; Give new names to the fragment uniforms so that the names - ;; do not clash with vertex globals. - (fragment-uniform-alpha-map (alpha-rename fragment-uniform-map)) + ;; do not clash with vertex globals and also that any + ;; uniforms in the vertex shader have the *same* name in the + ;; fragment shader. + (fragment-uniform-alpha-map + (map (match-lambda + ((original-name . alpha-name) + (cons alpha-name + (or (assq-ref vertex-uniform-alpha-map + (assq-ref vertex-uniform-map original-name)) + (unique-identifier))))) + fragment-uniform-map)) ;; This one is a little messy but what's happening is that ;; the GLSL name for each fragment output is mapped to the ;; respective renamed input. Vertex shader output names must @@ -2843,11 +2920,11 @@ (define vertex-glsl (call-with-output-string (lambda (port) - (emit-glsl vertex* version port)))) + (emit-glsl vertex* 'fragment version port)))) (define fragment-glsl (call-with-output-string (lambda (port) - (emit-glsl fragment* version port)))) + (emit-glsl fragment* 'fragment version port)))) (display vertex-glsl) (newline) (display fragment-glsl) |