summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chickadee/graphics/seagull.scm247
1 files changed, 162 insertions, 85 deletions
diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm
index 4e271ff..44a9bd9 100644
--- a/chickadee/graphics/seagull.scm
+++ b/chickadee/graphics/seagull.scm
@@ -39,12 +39,15 @@
;;
;; TODO:
;; - Loops
+;; - discard
;; - (define ...) form
+;; - struct field aliases (rgba for vec4, for example)
;; - Scheme shader type -> GLSL struct translation
;; - Dead code elimination (error when a uniform is eliminated)
;; - User defined structs
;; - Multiple GLSL versions
;; - Better error messages (especially around type predicate failure)
+;; - Refactor to add define-primitive syntax
;;
;;; Code:
(define-module (chickadee graphics seagull)
@@ -114,7 +117,7 @@
(memq x '(int->float float->int)))
(define (math-function? x)
- (memq x '(abs sqrt min max mod floor ceil sin cos tan clamp mix)))
+ (memq x '(abs sqrt min max mod floor ceil sin cos tan clamp mix length)))
(define (vertex-primitive-call? x)
#f)
@@ -1235,8 +1238,8 @@
,(apply-substitution-to-type b from to)))
(('struct-field struct field var)
`(struct-field ,(apply-substitution-to-type struct from to)
- ,field
- ,(apply-substitution-to-type var from to)))
+ ,field
+ ,(apply-substitution-to-type var from to)))
(('array-element array var)
`(array-element ,(apply-substitution-to-type array from to)
,(apply-substitution-to-type var from to)))))
@@ -1405,7 +1408,10 @@
;; struct type.
(('struct-field struct field field-var)
(if (struct-type? struct)
- (values #t (list (cons field-var (struct-type-ref struct field))))
+ (let ((field-type (struct-type-ref struct field)))
+ (if field-type
+ (values #t (list (cons field-var field-type)))
+ (values #f '())))
(values pred '())))
;; Substitute the element var when array has been resolved to an
;; array type.
@@ -1774,11 +1780,15 @@
(infer-exp exp env))
(define exp-type (single-type exp*))
(define tvar (fresh-variable-type))
- (values (texp (list tvar)
- `(struct-ref ,exp* ,field))
- exp-subs
- (compose-predicates exp-pred
- (predicate:struct-field exp-type field tvar))))
+ (define-values (pred combined-subs)
+ (eval-predicate* (compose-predicates (predicate:struct-field exp-type field tvar)
+ exp-pred)
+ exp-subs))
+ (values (texp (list (apply-substitutions-to-type tvar combined-subs))
+ `(struct-ref ,(apply-substitutions-to-texp exp* combined-subs)
+ ,field))
+ combined-subs
+ pred))
(define (infer:array-ref array-exp index-exp env)
(define-values (array-exp* array-exp-subs array-exp-pred)
@@ -1823,7 +1833,11 @@
(eval-predicate* (compose-predicates exp-pred body-pred)
(compose-substitutions exp-subs body-subs)))
(values (texp (texp-types body*)
- `(let ,(map list names exps*) ,body*))
+ `(let ,(map (lambda (name exp)
+ (list name (apply-substitutions-to-texp
+ exp combined-subs)))
+ names exps*)
+ ,(apply-substitutions-to-texp body* combined-subs)))
combined-subs
pred))
@@ -1848,7 +1862,11 @@
(define-values (pred combined-subs)
(eval-predicate* exp-pred (compose-substitutions exp-subs unify-subs)))
(values (texp (list type:outputs)
- `(outputs ,@(map list names exps*)))
+ `(outputs
+ ,@(map (lambda (name exp)
+ (list name (apply-substitutions-to-texp
+ exp combined-subs)))
+ names exps*)))
combined-subs
pred))
@@ -2025,8 +2043,8 @@
(type:vec4 type:float type:vec4)
(type:float type:vec4 type:vec4)
(type:mat3 type:mat3 type:mat3)
- (type:mat3 type:vec3 type:mat3)
- (type:vec3 type:mat3 type:mat3)
+ (type:mat3 type:vec3 type:vec3)
+ (type:vec3 type:mat3 type:vec3)
(type:mat4 type:mat4 type:mat4)
(type:mat4 type:vec4 type:vec4)
(type:vec4 type:mat4 type:vec4)))
@@ -2087,6 +2105,17 @@
(define type:make-vec4
(list (function-type (list type:float type:float type:float type:float)
(list type:vec4))))
+ (define type:length
+ (let ((a (fresh-variable-type)))
+ (list (type-scheme
+ (list a)
+ (qualified-type
+ (function-type (list a) (list type:float))
+ (predicate:or
+ (predicate:= a type:float)
+ (predicate:= a type:vec2)
+ (predicate:= a type:vec3)
+ (predicate:= a type:vec4)))))))
(define type:abs
(let ((a (fresh-variable-type)))
(list (type-scheme
@@ -2157,6 +2186,7 @@
(vec2 . ,type:make-vec2)
(vec3 . ,type:make-vec3)
(vec4 . ,type:make-vec4)
+ (length . ,type:length)
(abs . ,type:abs)
(sqrt . ,type:sqrt)
(min . ,type:min/max)
@@ -2236,14 +2266,25 @@
(define (vars->subs exp env)
(match exp
(('t ((? variable-type? tvar)) (? symbol? name))
- (list (cons tvar (lookup name env))))
+ (let ((type (lookup* name env)))
+ (if type
+ (list (cons tvar type))
+ '())))
((head . rest)
(delete-duplicates
(append (vars->subs head env)
(vars->subs rest env))))
(_ '())))
-(define (resolve-overloads program)
+(define (untype x)
+ (match x
+ (('t (_ ...) exp)
+ (untype exp))
+ ((exp . rest)
+ (cons (untype exp) (untype rest)))
+ (_ x)))
+
+(define (resolve-overloads program stage)
;; Find all of the struct types used in the program. They will be
;; used to generate overloaded functions that take one or more
;; structs as arguments.
@@ -2251,7 +2292,8 @@
(match program
(('t types ('top-level bindings body))
(define bindings*
- (let loop ((bindings bindings))
+ (let loop ((bindings bindings)
+ (globals (empty-env)))
(match bindings
(() '())
((('function name ('t ((? type-scheme? type)) func)) . rest)
@@ -2267,8 +2309,16 @@
(('lambda (params ...) _)
params)))
(define env
- (fold extend-env (empty-env) params
- (function-type-parameters type*)))
+ (compose-envs (fold extend-env (empty-env) params
+ (map list (function-type-parameters type*)))
+ globals))
+ (pk 're-infer-func env)
+ (match func
+ (('lambda _ body)
+ (pretty-print
+ (infer-exp (untype body)
+ (compose-envs env
+ (top-level-type-env stage))))))
(define subs*
(compose-substitutions subs
(vars->subs func env)))
@@ -2283,9 +2333,17 @@
;; rest)
;; (find-signatures name body))
))
- (loop rest)))
- ((binding . rest)
- (cons binding (loop rest))))))
+ (loop rest globals)))
+ ((('function name texp) . rest)
+ (cons `(function ,name ,texp)
+ (loop rest globals)))
+ (((qualifier type name) . rest)
+ (pk 'global qualifier type name)
+ (cons (list qualifier type name)
+ (loop rest
+ (extend-env name
+ (list (type-descriptor->type type))
+ globals)))))))
`(t ,types (top-level ,bindings* ,body)))
(_ (error "expected top-level form" program))))
@@ -2330,43 +2388,43 @@
(when (> n 0)
(display (make-string (* n 2) #\space) port)))
-(define (emit:int n version port level)
+(define (emit:int n stage version port level)
(define temp (unique-identifier))
(indent level port)
(format port "int ~a = ~a;\n" temp n)
(list temp))
-(define (emit:float n version port level)
+(define (emit:float n stage version port level)
(define temp (unique-identifier))
(indent level port)
(format port "float ~a = ~a;\n" temp n)
(list temp))
-(define (emit:boolean b version port level)
+(define (emit:boolean b stage version port level)
(define temp (unique-identifier))
(indent level port)
(format port "bool ~a = ~a;\n" temp (if b "true" "false"))
(list temp))
-(define (emit:binary-operator type op a b version port level)
+(define (emit:binary-operator type op a b stage version port level)
(define op*
(case op
((=) '==)
(else op)))
- (define a-temp (single-temp (emit-glsl a version port level)))
- (define b-temp (single-temp (emit-glsl b version port level)))
+ (define a-temp (single-temp (emit-glsl a stage version port level)))
+ (define b-temp (single-temp (emit-glsl b stage version port level)))
(define temp (unique-identifier))
(indent level port)
(format port "~a ~a = ~a ~a ~a;\n"
(type->glsl type) temp a-temp op* b-temp)
(list temp))
-(define (emit:unary-operator type op a version port level)
+(define (emit:unary-operator type op a stage version port level)
(define op*
(case op
((not) '!)
(else op)))
- (define a-temp (single-temp (emit-glsl a version port level)))
+ (define a-temp (single-temp (emit-glsl a stage version port level)))
(define temp (unique-identifier))
(indent level port)
(format port "~a ~a = ~a(~a);\n"
@@ -2387,10 +2445,11 @@
types lhs-list rhs-list*))
(define (emit:mov a b port level)
- (indent level port)
- (format port "~a = ~a;\n" a b))
+ (when a
+ (indent level port)
+ (format port "~a = ~a;\n" a b)))
-(define (emit:function name type params body version port level)
+(define (emit:function name type params body stage version port level)
(define param-types (function-type-parameters type))
(define return-types (function-type-returns type))
(define outputs (unique-identifiers-for-list return-types))
@@ -2412,29 +2471,32 @@
qualifier (type->glsl type) name)
(loop rest #f))))
(display ") {\n" port)
- (define body-temps (emit-glsl body version port (+ level 1)))
+ (define body-temps (emit-glsl body stage version port (+ level 1)))
(for-each (lambda (output temp)
(emit:mov output temp port (+ level 1)))
outputs body-temps)
(indent level port)
(display "}\n" port))
-(define (emit:if predicate consequent alternate version port level)
- (define if-temps (unique-identifiers-for-list (texp-types consequent)))
+(define (emit:if predicate consequent alternate stage version port level)
+ (define if-temps
+ (if (equal? (texp-types consequent) (list type:outputs))
+ '(#f)
+ (unique-identifiers-for-list (texp-types consequent))))
(emit:declarations (texp-types consequent) if-temps #f port level)
(define predicate-temp
- (single-temp (emit-glsl predicate version port level)))
+ (single-temp (emit-glsl predicate stage version port level)))
(indent level port)
(format port "if(~a) {\n" predicate-temp)
(define consequent-temps
- (emit-glsl consequent version port (+ level 1)))
+ (emit-glsl consequent stage version port (+ level 1)))
(for-each (lambda (lhs rhs)
(emit:mov lhs rhs port (+ level 1)))
if-temps consequent-temps)
(indent level port)
(display "} else {\n" port)
(define alternate-temps
- (emit-glsl alternate version port (+ level 1)))
+ (emit-glsl alternate stage version port (+ level 1)))
(for-each (lambda (lhs rhs)
(emit:mov lhs rhs port (+ level 1)))
if-temps alternate-temps)
@@ -2442,19 +2504,19 @@
(display "}\n" port)
if-temps)
-(define (emit:values exps version port level)
+(define (emit:values exps stage version port level)
(append-map (lambda (exp)
- (emit-glsl exp version port level))
+ (emit-glsl exp stage version port level))
exps))
-(define (emit:let types names exps body version port level)
+(define (emit:let types names exps body stage version port level)
(define binding-temps
(map (lambda (exp)
- (single-temp (emit-glsl exp version port level)))
+ (single-temp (emit-glsl exp stage version port level)))
exps))
(define binding-types (map single-type exps))
(emit:declarations binding-types names binding-temps port level)
- (define body-temps (emit-glsl body version port level))
+ (define body-temps (emit-glsl body stage version port level))
(define let-temps (unique-identifiers-for-list types))
(emit:declarations (texp-types body) let-temps body-temps port level)
let-temps)
@@ -2464,12 +2526,12 @@
(int->float . float)
(texture-2d . texture2D)))
-(define (emit:primcall type operator args version port level)
+(define (emit:primcall type operator args stage version port level)
(define operator*
(or (assq-ref %primcall-map operator) operator))
(define arg-temps
(map (lambda (arg)
- (single-temp (emit-glsl arg version port level)))
+ (single-temp (emit-glsl arg stage version port level)))
args))
(define output-temp (unique-identifier))
(indent level port)
@@ -2480,11 +2542,11 @@
(string-join (map symbol->string arg-temps) ", "))
(list output-temp))
-(define (emit:call types operator args version port level)
- (define operator-name (single-temp (emit-glsl operator version port)))
+(define (emit:call types operator args stage version port level)
+ (define operator-name (single-temp (emit-glsl operator stage version port)))
(define arg-temps
(map (lambda (arg)
- (single-temp (emit-glsl arg version port level)))
+ (single-temp (emit-glsl arg stage version port level)))
args))
(define output-temps (unique-identifiers-for-list types))
(emit:declarations types output-temps #f port level)
@@ -2495,8 +2557,8 @@
", "))
output-temps)
-(define (emit:struct-ref type exp field version port level)
- (define input-temp (single-temp (emit-glsl exp version port level)))
+(define (emit:struct-ref type exp field stage version port level)
+ (define input-temp (single-temp (emit-glsl exp stage version port level)))
(define output-temp (unique-identifier))
(indent level port)
(format port "~a ~a = ~a.~a;\n"
@@ -2506,9 +2568,9 @@
field)
(list output-temp))
-(define (emit:array-ref type array-exp index-exp version port level)
- (define array-temp (single-temp (emit-glsl array-exp version port level)))
- (define index-temp (single-temp (emit-glsl index-exp version port level)))
+(define (emit:array-ref type array-exp index-exp stage version port level)
+ (define array-temp (single-temp (emit-glsl array-exp stage version port level)))
+ (define index-temp (single-temp (emit-glsl index-exp stage version port level)))
(define output-temp (unique-identifier))
(indent level port)
(format port "~a ~a = ~a[~a];\n"
@@ -2518,7 +2580,7 @@
index-temp)
(list output-temp))
-(define (emit:top-level bindings body version port level)
+(define (emit:top-level bindings body stage version port level)
(for-each (match-lambda
(((? top-level-qualifier? qualifier) type-desc name)
(format port "~a ~a ~a;\n"
@@ -2526,10 +2588,10 @@
(type-descriptor->glsl type-desc)
name))
(('function name ('t (type) ('lambda params body)))
- (emit:function name type params body version port level)))
+ (emit:function name type params body stage version port level)))
bindings)
(display "void main() {\n" port)
- (emit-glsl body version port (+ level 1))
+ (emit-glsl body stage version port (+ level 1))
(display "}\n" port))
(define %built-in-output-map
@@ -2538,51 +2600,55 @@
(vertex:clip-distance . gl_ClipDistance)
(fragment:depth . gl_FragDepth)))
-(define (emit:outputs names exps version port level)
+(define (emit:outputs names exps stage version port level)
(define (output-name name)
(or (assq-ref %built-in-output-map name) name))
- (for-each (lambda (name exp)
- (match (emit-glsl exp version port level)
- ((temp)
- (indent level port)
- (format port "~a = ~a;\n"
- (output-name name)
- temp))))
- names exps)
+ (if (and (eq? stage 'fragment) (null? names))
+ (begin
+ (indent level port)
+ (format port "discard;\n"))
+ (for-each (lambda (name exp)
+ (match (emit-glsl exp stage version port level)
+ ((temp)
+ (indent level port)
+ (format port "~a = ~a;\n"
+ (output-name name)
+ temp))))
+ names exps))
'(#f))
-(define* (emit-glsl exp version port #:optional (level 0))
+(define* (emit-glsl exp stage version port #:optional (level 0))
(match exp
(('t _ (? exact-integer? n))
- (emit:int n version port level))
+ (emit:int n stage version port level))
(('t _ (? float? n))
- (emit:float n version port level))
+ (emit:float n stage version port level))
(('t _ (? boolean? b))
- (emit:boolean b version port level))
+ (emit:boolean b stage version port level))
(('t _ (? symbol? var))
(list var))
(('t _ ('if predicate consequent alternate))
- (emit:if predicate consequent alternate version port level))
+ (emit:if predicate consequent alternate stage version port level))
(('t _ ('values exps ...))
- (emit:values exps version port level))
+ (emit:values exps stage version port level))
(('t types ('let ((names exps) ...) body))
- (emit:let types names exps body version port level))
+ (emit:let types names exps body stage version port level))
(('t (type) ('primcall (? binary-operator? op) a b))
- (emit:binary-operator type op a b version port level))
+ (emit:binary-operator type op a b stage version port level))
(('t (type) ('primcall (? unary-operator? op) a))
- (emit:unary-operator type op a version port level))
+ (emit:unary-operator type op a stage version port level))
(('t (type) ('primcall op args ...))
- (emit:primcall type op args version port level))
+ (emit:primcall type op args stage version port level))
(('t types ('call operator args ...))
- (emit:call types operator args version port level))
+ (emit:call types operator args stage version port level))
(('t (type) ('struct-ref exp field))
- (emit:struct-ref type exp field version port level))
+ (emit:struct-ref type exp field stage version port level))
(('t (type) ('array-ref array-exp index-exp))
- (emit:array-ref type array-exp index-exp version port level))
+ (emit:array-ref type array-exp index-exp stage version port level))
(('t _ ('outputs (names exps) ...))
- (emit:outputs names exps version port level))
+ (emit:outputs names exps stage version port level))
(('t _ ('top-level (bindings ...) body))
- (emit:top-level bindings body version port level))))
+ (emit:top-level bindings body stage version port level))))
;;;
@@ -2657,7 +2723,7 @@
(let* ((propagated (propagate-constants expanded (empty-env)))
(hoisted (hoist-functions* propagated))
(inferred (infer-types hoisted stage))
- (resolved (resolve-overloads inferred)))
+ (resolved (resolve-overloads inferred stage)))
(values resolved global-map (unique-identifier-counter))))))
(define (specs->globals specs)
@@ -2666,6 +2732,7 @@
(make-seagull-global qualifier type-desc name)))
specs))
+(use-modules (ice-9 pretty-print))
;; Using syntax-case allows us to compile shaders to their fully typed
;; intermediate form at compile time, leaving only GLSL emission for
;; runtime.
@@ -2685,6 +2752,7 @@
#:inputs inputs
#:outputs outputs
#:uniforms uniforms))
+ (pretty-print compiled)
(with-syntax ((inputs (datum->syntax x inputs))
(outputs (datum->syntax x outputs))
(uniforms (datum->syntax x uniforms))
@@ -2784,8 +2852,17 @@
(fragment-uniform-map (map-globals (seagull-module-uniforms fragment)
fragment-global-map))
;; Give new names to the fragment uniforms so that the names
- ;; do not clash with vertex globals.
- (fragment-uniform-alpha-map (alpha-rename fragment-uniform-map))
+ ;; do not clash with vertex globals and also that any
+ ;; uniforms in the vertex shader have the *same* name in the
+ ;; fragment shader.
+ (fragment-uniform-alpha-map
+ (map (match-lambda
+ ((original-name . alpha-name)
+ (cons alpha-name
+ (or (assq-ref vertex-uniform-alpha-map
+ (assq-ref vertex-uniform-map original-name))
+ (unique-identifier)))))
+ fragment-uniform-map))
;; This one is a little messy but what's happening is that
;; the GLSL name for each fragment output is mapped to the
;; respective renamed input. Vertex shader output names must
@@ -2843,11 +2920,11 @@
(define vertex-glsl
(call-with-output-string
(lambda (port)
- (emit-glsl vertex* version port))))
+ (emit-glsl vertex* 'fragment version port))))
(define fragment-glsl
(call-with-output-string
(lambda (port)
- (emit-glsl fragment* version port))))
+ (emit-glsl fragment* 'fragment version port))))
(display vertex-glsl)
(newline)
(display fragment-glsl)