From ae40818c8b70d037c8f9d0456748afd7a3fd6bc1 Mon Sep 17 00:00:00 2001 From: David Thompson Date: Tue, 24 Jan 2023 20:15:27 -0500 Subject: Add start of vertex/fragment shader distinction. --- chickadee/graphics/seagull.scm | 316 ++++++++++++++++++++++------------------- 1 file changed, 167 insertions(+), 149 deletions(-) diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm index 0a54b88..9f5aaeb 100644 --- a/chickadee/graphics/seagull.scm +++ b/chickadee/graphics/seagull.scm @@ -84,12 +84,23 @@ (define (math-function? x) (memq x '(abs sqrt min max sin cos tan clamp mix))) -(define (primitive-call? x) +(define (vertex-primitive-call? x) + #f) + +(define (fragment-primitive-call? x) + (memq x '(texture-2d))) + +(define (primitive-call? x stage) (or (binary-operator? x) (unary-operator? x) (vector-constructor? x) (conversion? x) - (math-function? x))) + (math-function? x) + (case stage + ((vertex) + (vertex-primitive-call? x)) + ((fragment) + (fragment-primitive-call? x))))) (define (top-level-qualifier? x) (memq x '(in out uniform))) @@ -218,106 +229,108 @@ (define names* (map (lambda (_name) (unique-identifier)) names)) (fold extend-env (empty-env) names names*)) -(define (expand:list exps env) - (map (lambda (exp) (expand exp env)) exps)) +(define (expand:list exps stage env) + (map (lambda (exp) (expand exp stage env)) exps)) -(define (expand:variable exp env) +(define (expand:variable exp stage env) ;; Replace original variable with alpha-converted name, but keep ;; track of the original for showing the user error messages that ;; make sense later. `(var ,(lookup exp env) ,exp)) -(define (expand:if predicate consequent alternate env) - `(if ,(expand predicate env) - ,(expand consequent env) - ,(expand alternate env))) +(define (expand:if predicate consequent alternate stage env) + `(if ,(expand predicate stage env) + ,(expand consequent stage env) + ,(expand alternate stage env))) -(define (expand:let names exps body env) +(define (expand:let names exps body stage env) (if (null? names) - (expand body env) - (let* ((exps* (map (lambda (exp) (expand exp env)) exps)) + (expand body stage env) + (let* ((exps* (map (lambda (exp) (expand exp stage env)) exps)) (env* (compose-envs (alpha-convert names) env)) (bindings* (map list (lookup-all names env*) exps*))) - `(let ,bindings* ,(expand body env*))))) + `(let ,bindings* ,(expand body stage env*))))) -(define (expand:let* bindings body env) +(define (expand:let* bindings body stage env) (match bindings - (() (expand body env)) + (() (expand body stage env)) ((binding . rest) (expand `(let (,binding) (let* ,rest ,body)) + stage env)))) -(define (expand:lambda params body env) +(define (expand:lambda params body stage env) (define env* (compose-envs (alpha-convert params) env)) (define params* (lookup-all params env*)) - `(lambda ,params* ,(expand body env*))) + `(lambda ,params* ,(expand body stage env*))) -(define (expand:values exps env) - `(values ,@(expand:list exps env))) +(define (expand:values exps stage env) + `(values ,@(expand:list exps stage env))) ;; Arithmetic operators, in true Lisp fashion, can accept many ;; arguments. + and * accept 0 or more. - and / accept one or more. ;; The expansion pass transforms all such expressions into binary ;; operator form. -(define (expand:+ args env) +(define (expand:+ args stage env) (match args (() 0) - ((n) (expand n env)) + ((n) (expand n stage env)) ((n . rest) - `(primcall + ,(expand n env) ,(expand:+ rest env))))) + `(primcall + ,(expand n stage env) ,(expand:+ rest stage env))))) -(define (expand:- args env) +(define (expand:- args stage env) (match args - ((n) `(primcall - ,(expand n env) 0)) + ((n) `(primcall - ,(expand n stage env) 0)) ((m n) - `(primcall - ,(expand m env) ,(expand n env))) + `(primcall - ,(expand m stage env) ,(expand n stage env))) ((n . rest) - `(primcall - ,(expand n env) ,(expand:- rest env))))) + `(primcall - ,(expand n stage env) ,(expand:- rest stage env))))) -(define (expand:* args env) +(define (expand:* args stage env) (match args (() 1) - ((n) (expand n env)) + ((n) (expand n stage env)) ((n . rest) - `(primcall * ,(expand n env) ,(expand:* rest env))))) + `(primcall * ,(expand n stage env) ,(expand:* rest stage env))))) -(define (expand:/ args env) +(define (expand:/ args stage env) (match args ((n) - `(primcall / 1 ,(expand n env))) + `(primcall / 1 ,(expand n stage env))) ((m n) - `(primcall / ,(expand m env) ,(expand n env))) + `(primcall / ,(expand m stage env) ,(expand n stage env))) ((m n . rest) (let loop ((rest rest) - (exp `(primcall / ,(expand m env) ,(expand n env)))) + (exp `(primcall / ,(expand m stage env) ,(expand n stage env)))) (match rest ((l) - `(primcall / ,exp ,(expand l env))) + `(primcall / ,exp ,(expand l stage env))) ((l . rest) - (loop rest `(primcall / ,exp ,(expand l env))))))))) + (loop rest `(primcall / ,exp ,(expand l stage env))))))))) -(define (expand:primitive-call operator operands env) - `(primcall ,operator ,@(expand:list operands env))) +(define (expand:primitive-call operator operands stage env) + `(primcall ,operator ,@(expand:list operands stage env))) -(define (expand:call operator operands env) - `(call ,(expand operator env) ,@(expand:list operands env))) +(define (expand:call operator operands stage env) + `(call ,(expand operator stage env) ,@(expand:list operands stage env))) -(define (expand:top-level qualifiers types names body env) +(define (expand:top-level qualifiers types names body stage env) (let* ((env* (compose-envs (alpha-convert names) env))) ;; TODO: Support interpolation qualifiers. `(top-level ,(map (lambda (qualifier type name) (list qualifier type (lookup name env*))) qualifiers types names) - ,(expand body env*)))) + ,(expand body stage env*)))) -(define (expand:outputs names exps env) - `(outputs ,@(map (lambda (name exp) - (list (if (built-in-output? name 'vertex) - name - (lookup name env)) - (expand exp env))) - names exps))) +(define (expand:outputs names exps stage env) + `(outputs + ,@(map (lambda (name exp) + (list (if (built-in-output? name (current-shader-stage)) + name + (lookup name env)) + (expand exp stage env))) + names exps))) (define &seagull-syntax-error (make-exception-type '&seagull-syntax-error &error '(form))) @@ -329,44 +342,46 @@ (exception-accessor &seagull-syntax-error (record-accessor &seagull-syntax-error 'form))) -(define (expand exp env) +(define (expand exp stage env) + (define (primitive-call-for-stage? x) + (primitive-call? x stage)) (match exp ;; Immediates and variables: ((? immediate?) exp) ((? symbol?) - (expand:variable exp env)) + (expand:variable exp stage env)) ;; Primitive syntax forms: (('if predicate consequent alternate) - (expand:if predicate consequent alternate env)) + (expand:if predicate consequent alternate stage env)) (('let (((? symbol? names) exps) ...) body) - (expand:let names exps body env)) + (expand:let names exps body stage env)) (('lambda ((? symbol? params) ...) body) - (expand:lambda params body env)) + (expand:lambda params body stage env)) (('values exps ...) - (expand:values exps env)) + (expand:values exps stage env)) (('outputs ((? symbol? names) exps) ...) - (expand:outputs names exps env)) + (expand:outputs names exps stage env)) (('top-level (((? top-level-qualifier? qualifiers) types names) ...) body) - (expand:top-level qualifiers types names body env)) + (expand:top-level qualifiers types names body stage env)) ;; Macros: (('let* (bindings ...) body) - (expand:let* bindings body env)) + (expand:let* bindings body stage env)) (('+ args ...) - (expand:+ args env)) + (expand:+ args stage env)) (('- args ...) - (expand:- args env)) + (expand:- args stage env)) (('* args ...) - (expand:* args env)) + (expand:* args stage env)) (('/ args ...) - (expand:/ args env)) + (expand:/ args stage env)) ;; Primitive calls: - (((? primitive-call? operator) args ...) - (expand:primitive-call operator args env)) + (((? primitive-call-for-stage? operator) args ...) + (expand:primitive-call operator args stage env)) ;; Function calls: ((operator args ...) - (expand:call operator args env)) + (expand:call operator args stage env)) ;; Syntax error: (_ (raise-exception @@ -695,6 +710,7 @@ (define vec4-type (primitive-type 'vec4)) (define mat3-type (primitive-type 'mat3)) (define mat4-type (primitive-type 'mat4)) +(define sampler-2d-type (primitive-type 'sampler-2d)) (define (type-name->type name) (case name @@ -705,7 +721,8 @@ ((vec3) vec3-type) ((vec4) vec4-type) ((mat3) mat3-type) - ((mat4) mat4-type))) + ((mat4) mat4-type) + ((sampler-2d) sampler-2d-type))) ;; Type variables: (define unique-type-variable-counter (make-parameter 0)) @@ -930,7 +947,7 @@ (function-type (list float-type float-type float-type) (list float-type))))) -(define (top-level-type-env) +(define (top-level-type-env stage) `((+ . ,add/sub-type) (- . ,add/sub-type) (* . ,mul-type) @@ -955,7 +972,12 @@ (tan . ,trigonometry-type) (clamp . ,clamp/mix-type) (mix . ,clamp/mix-type) - (gl-position ,vec4-type))) + ,@(case stage + ((vertex) + `((gl-position ,vec4-type))) + ((fragment) + `((texture-2d ,(function-type (list sampler-2d-type vec2-type) + (list vec4-type)))))))) (define (occurs? a b) (cond @@ -969,7 +991,6 @@ (else #f))) (define (apply-substitution-to-type type from to) - (pk 'apply-substitution-to-type type from to) (cond ((or (primitive-type? type) (outputs-type? type)) @@ -1169,7 +1190,7 @@ (texp (list outputs-type) `(outputs ,@(map (lambda (name exp) - (list (texp (pk 'types name (lookup name env)) name) + (list (texp (lookup name env) name) (annotate-exp exp env))) names exps)))) @@ -1200,8 +1221,8 @@ (('top-level bindings body) (annotate:top-level bindings body env)))) -(define (annotate-exp* exp) - (annotate-exp exp (top-level-type-env))) +(define (annotate-exp* exp stage) + (annotate-exp exp (top-level-type-env stage))) ;;; @@ -1235,17 +1256,14 @@ (apply handler args)))) (define (unify:fail . args) - (pk 'unify:fail args) (abort-to-prompt unify-prompt-tag args)) (define (unify:primitives a b success) - (pk 'unify:primitives a b) (if (equal? a b) (success '()) (unify:fail "type mismatch" a b))) (define (unify:variable a b success) - (pk 'unify:variable a b) (cond ((eq? a b) (success '())) @@ -1255,7 +1273,6 @@ (success (list (cons a b)))))) (define (unify:functions a b success) - (pk 'unify:functions a b) (unify (function-type-parameters a) (function-type-parameters b) (lambda (sub0) @@ -1267,13 +1284,11 @@ (success (compose-substitutions sub0 sub1))))))) (define (unify:overload a b success) - (pk 'unify:overload a b) (define (try-unify functions) (match functions (() (unify:fail "no matching overload" a b)) ((function . rest) - (pk 'try-overload function b) (call-with-unify-rollback (lambda () (unify function b success)) @@ -1282,7 +1297,6 @@ (try-unify (overload-type-ref a))) (define (unify:lists a rest-a b rest-b success) - (pk 'unify:lists a rest-a b rest-b) (unify a b (lambda (sub0) (unify (apply-substitutions-to-types rest-a sub0) @@ -1291,7 +1305,6 @@ (success (compose-substitutions sub0 sub1))))))) (define (unify a b success) - (pk 'unify a b) (match (list a b) (((? primitive-type? a) (? primitive-type? b)) (unify:primitives a b success)) @@ -1494,10 +1507,10 @@ (resolve:list exps env)) (_ exp))) -(define (infer-types exp) +(define (infer-types exp stage) (call-with-unify-rollback (lambda () - (let ((annotated (pk 'annotated (annotate-exp* exp)))) + (let ((annotated (annotate-exp* exp stage))) (infer annotated '() (lambda (subs) @@ -1523,43 +1536,43 @@ (when (> n 0) (display (make-string (* n 2) #\space) port))) -(define (emit:int n port level) +(define (emit:int n version port level) (define temp (unique-identifier)) (indent level port) (format port "int ~a = ~a;\n" temp n) (list temp)) -(define (emit:float n port level) +(define (emit:float n version port level) (define temp (unique-identifier)) (indent level port) (format port "float ~a = ~a;\n" temp n) (list temp)) -(define (emit:boolean b port level) +(define (emit:boolean b version port level) (define temp (unique-identifier)) (indent level port) (format port "bool ~a = ~a;\n" temp (if b "true" "false")) (list temp)) -(define (emit:binary-operator type op a b port level) +(define (emit:binary-operator type op a b version port level) (define op* (case op ((=) '==) (else op))) - (define a-temp (single-temp (emit-glsl a port level))) - (define b-temp (single-temp (emit-glsl b port level))) + (define a-temp (single-temp (emit-glsl a version port level))) + (define b-temp (single-temp (emit-glsl b version port level))) (define temp (unique-identifier)) (indent level port) (format port "~a ~a = ~a ~a ~a;\n" (type-name type) temp a-temp op* b-temp) (list temp)) -(define (emit:unary-operator type op a port level) +(define (emit:unary-operator type op a version port level) (define op* (case op ((not) '!) (else op))) - (define a-temp (single-temp (emit-glsl a port level))) + (define a-temp (single-temp (emit-glsl a version port level))) (define temp (unique-identifier)) (indent level port) (format port "~a ~a = ~a(~a);\n" @@ -1582,7 +1595,7 @@ (indent level port) (format port "~a = ~a;\n" a b)) -(define (emit:function name type params body port level) +(define (emit:function name type params body version port level) (define param-types (function-type-parameters type)) (define return-types (function-type-returns type)) (define outputs (unique-identifiers-for-list return-types)) @@ -1604,26 +1617,29 @@ qualifier (type-name type) name) (loop rest #f)))) (display ") {\n" port) - (define body-temps (emit-glsl body port (+ level 1))) + (define body-temps (emit-glsl body version port (+ level 1))) (for-each (lambda (output temp) (emit:mov output temp port (+ level 1))) outputs body-temps) (indent level port) (display "}\n" port)) -(define (emit:if predicate consequent alternate port level) +(define (emit:if predicate consequent alternate version port level) (define if-temps (unique-identifiers-for-list (texp-types consequent))) (emit:declarations (texp-types consequent) if-temps #f port level) - (define predicate-temp (single-temp (emit-glsl predicate port level))) + (define predicate-temp + (single-temp (emit-glsl predicate version port level))) (indent level port) (format port "if(~a) {\n" predicate-temp) - (define consequent-temps (emit-glsl consequent port (+ level 1))) + (define consequent-temps + (emit-glsl consequent version port (+ level 1))) (for-each (lambda (lhs rhs) (emit:mov lhs rhs port (+ level 1))) if-temps consequent-temps) (indent level port) (display "else {\n" port) - (define alternate-temps (emit-glsl alternate port (+ level 1))) + (define alternate-temps + (emit-glsl alternate version port (+ level 1))) (for-each (lambda (lhs rhs) (emit:mov lhs rhs port (+ level 1))) if-temps alternate-temps) @@ -1631,32 +1647,34 @@ (display "}\n" port) if-temps) -(define (emit:values exps port level) +(define (emit:values exps version port level) (append-map (lambda (exp) - (emit-glsl exp port level)) + (emit-glsl exp version port level)) exps)) -(define (emit:let types names exps body port level) +(define (emit:let types names exps body version port level) (define binding-temps (map (lambda (exp) - (single-temp (emit-glsl exp port level))) + (single-temp (emit-glsl exp version port level))) exps)) (define binding-types (map single-type exps)) (emit:declarations binding-types names binding-temps port level) - (define body-temps (emit-glsl body port level)) + (define body-temps (emit-glsl body version port level)) (define let-temps (unique-identifiers-for-list types)) (emit:declarations (texp-types body) let-temps body-temps port level) let-temps) -(define (emit:primcall type operator args port level) +(define %primcall-map + '((float->int . float) + (int->float . int) + (texture-2d . texture2D))) + +(define (emit:primcall type operator args version port level) (define operator* - (case operator - ((float->int) 'float) - ((int->float) 'int) - (else operator))) + (or (assq-ref %primcall-map operator) operator)) (define arg-temps (map (lambda (arg) - (single-temp (emit-glsl arg port level))) + (single-temp (emit-glsl arg version port level))) args)) (define output-temp (unique-identifier)) (indent level port) @@ -1667,11 +1685,11 @@ (string-join (map symbol->string arg-temps) ", ")) (list output-temp)) -(define (emit:call types operator args port level) - (define operator-name (single-temp (emit-glsl operator port))) +(define (emit:call types operator args version port level) + (define operator-name (single-temp (emit-glsl operator version port))) (define arg-temps (map (lambda (arg) - (single-temp (emit-glsl arg port level))) + (single-temp (emit-glsl arg version port level))) args)) (define output-temps (unique-identifiers-for-list types)) (emit:declarations types output-temps #f port level) @@ -1682,28 +1700,34 @@ ", ")) output-temps) -(define (emit:top-level bindings body port level) +(define %type-name-map + '((sampler-2d . sampler2D))) + +(define (emit:top-level bindings body version port level) (for-each (match-lambda (((? top-level-qualifier? qualifier) type-name name) - (format port "~a ~a ~a;\n" qualifier type-name name)) + (define type-name* + (or (assq-ref %type-name-map type-name) type-name)) + (format port "~a ~a ~a;\n" qualifier type-name* name)) (('function name ('t (type) ('lambda params body))) - (emit:function name type params body port level))) + (emit:function name type params body version port level))) bindings) (display "void main() {\n" port) - (emit-glsl body port (+ level 1)) + (emit-glsl body version port (+ level 1)) (display "}\n" port)) -(define (emit:outputs names exps port level) +(define %built-in-output-map + '((gl-position . gl_Position) + (gl-point-size . gl_PointSize) + (gl-clip-distance . gl_ClipDistance) + (gl-frag-depth . gl_FragDepth) + (gl-sample-mask . gl_SampleMask))) + +(define (emit:outputs names exps version port level) (define (output-name name) - (case name - ((gl-position) 'gl_Position) - ((gl-point-size) 'gl_PointSize) - ((gl-clip-distance) 'gl_ClipDistance) - ((gl-frag-depth) 'gl_FragDepth) - ((gl-sample-mask) 'gl_SampleMask) - (else name))) + (or (assq-ref %built-in-output-map name) name)) (for-each (lambda (name exp) - (match (emit-glsl exp port level) + (match (emit-glsl exp version port level) ((temp) (indent level port) (format port "~a = ~a;\n" @@ -1711,34 +1735,34 @@ temp)))) names exps)) -(define* (emit-glsl exp port #:optional (level 0)) +(define* (emit-glsl exp version port #:optional (level 0)) (match exp (('t _ (? exact-integer? n)) - (emit:int n port level)) + (emit:int n version port level)) (('t _ (? float? n)) - (emit:float n port level)) + (emit:float n version port level)) (('t _ (? boolean? b)) - (emit:boolean b port level)) + (emit:boolean b version port level)) (('t _ ('var var _)) (list var)) (('t _ ('if predicate consequent alternate)) - (emit:if predicate consequent alternate port level)) + (emit:if predicate consequent alternate version port level)) (('t _ ('values exps ...)) - (emit:values exps port level)) + (emit:values exps version port level)) (('t types ('let ((names exps) ...) body)) - (emit:let types names exps body port level)) + (emit:let types names exps body version port level)) (('t (type) ('primcall ('t _ (? binary-operator? op)) a b)) - (emit:binary-operator type op a b port level)) + (emit:binary-operator type op a b version port level)) (('t (type) ('primcall ('t _ (? unary-operator? op)) a)) - (emit:unary-operator type op a port level)) + (emit:unary-operator type op a version port level)) (('t (type) ('primcall ('t _ op) args ...)) - (emit:primcall type op args port level)) + (emit:primcall type op args version port level)) (('t types ('call operator args ...)) - (emit:call types operator args port level)) + (emit:call types operator args version port level)) (('t _ ('outputs (names exps) ...)) - (emit:outputs names exps port level)) + (emit:outputs names exps version port level)) (('t _ ('top-level (bindings ...) body)) - (emit:top-level bindings body port level)))) + (emit:top-level bindings body version port level)))) ;;; @@ -1748,19 +1772,13 @@ ;; Combine all of the compiler passes on a user provided program and ;; emit GLSL code if the program is valid. -(define (print-list lst) - (for-each (lambda (x) (format #t "~a\n" x)) lst)) - -;; Substitutions aren't being generated correctly. -(define* (compile-seagull exp #:optional (port (current-output-port))) +(define* (compile-seagull exp #:key (stage 'vertex) + (version '330) + (port (current-output-port))) (parameterize ((unique-identifier-counter 0) (unique-type-variable-counter 0)) - (let* ((expanded (pk 'expanded (expand exp (top-level-env)))) - (propagated (pk 'propagated (propagate-constants expanded (empty-env)))) - (hoisted (pk 'hoisted (hoist-functions* propagated))) - (inferred (pk 'inferred (infer-types hoisted)))) - (display "*** BEGIN GLSL OUTPUT ***\n" port) - (emit-glsl inferred port) - (newline port) - (display "*** END GLSL OUTPUT ***\n" port) - inferred))) + (let* ((expanded (expand exp stage (top-level-env))) + (propagated (propagate-constants expanded (empty-env))) + (hoisted (hoist-functions* propagated)) + (inferred (infer-types hoisted stage))) + (emit-glsl inferred version port)))) -- cgit v1.2.3