summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Thompson <dthompson2@worcester.edu>2023-02-06 08:10:33 -0500
committerDavid Thompson <dthompson2@worcester.edu>2023-06-08 08:14:41 -0400
commit19ee392840e3e5d840a03924d89f410b039fdf50 (patch)
treeb51ac1d8fc09f4bc5912f0b8ab0ed96d3c845e46
parent93cb7c6ab2df3c33c41cc00df0bcb7727cce11e3 (diff)
Add struct types.
-rw-r--r--chickadee/graphics/seagull.scm397
1 files changed, 283 insertions, 114 deletions
diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm
index bb7008f..ca11ec3 100644
--- a/chickadee/graphics/seagull.scm
+++ b/chickadee/graphics/seagull.scm
@@ -292,6 +292,18 @@
(define (expand:values exps stage env)
`(values ,@(expand:list exps stage env)))
+(define (expand:-> exp members stage env)
+ (define exp* (expand exp stage env))
+ (match members
+ ((member . rest)
+ (let loop ((members rest)
+ (prev `(struct-ref ,exp* ,member)))
+ (match members
+ (() prev)
+ ((next . rest)
+ (loop `(struct-ref ,prev ,next)
+ rest)))))))
+
;; Arithmetic operators, in true Lisp fashion, can accept many
;; arguments. + and * accept 0 or more. - and / accept one or more.
;; The expansion pass transforms all such expressions into binary
@@ -437,6 +449,8 @@
body)
(expand:top-level qualifiers types names body stage env))
;; Macros:
+ (('-> exp (? symbol? members) ..1)
+ (expand:-> exp members stage env))
(('let* (bindings ...) body)
(expand:let* bindings body stage env))
(('+ args ...)
@@ -526,6 +540,9 @@
(propagate-constants arg env))
args)))
+(define (propagate:struct-ref exp member env)
+ `(struct-ref ,(propagate-constants exp env) ,member))
+
;; The division of two integers can result in a rational, non-integer,
;; such as 1/2. This isn't how integer division works in GLSL, so we
;; need to round the result to an integer.
@@ -576,6 +593,8 @@
(propagate:primcall operator args env))
(('call operator args ...)
(propagate:call operator args env))
+ (('struct-ref exp member)
+ (propagate:struct-ref exp member env))
(('outputs (names exps) ...)
(propagate:outputs names exps env))
(('top-level inputs body)
@@ -637,6 +656,8 @@
((or ('primcall _ args ...)
('call args ...))
(check-free-variables-in-list args bound-vars top-level-vars))
+ (('struct-ref exp member)
+ (check-free-variables exp bound-vars top-level-vars))
(('outputs (names exps) ...)
(check-free-variables-in-list exps bound-vars top-level-vars))
(('top-level ((_ _ names) ...) body)
@@ -695,8 +716,7 @@
(define (hoist:lambda params body)
(define-values (body* body-env)
(hoist-functions body))
- (values `(lambda ,params ,body*)
- body-env))
+ (values `(lambda ,params ,body*) body-env))
(define (hoist:values exps)
(define-values (exps* exp-env)
@@ -705,13 +725,15 @@
(define (hoist:primcall operator args)
(define-values (args* args-env) (hoist:list args))
- (values `(primcall ,operator ,@args*)
- args-env))
+ (values `(primcall ,operator ,@args*) args-env))
(define (hoist:call args)
(define-values (args* args-env) (hoist:list args))
- (values `(call ,@args*)
- args-env))
+ (values `(call ,@args*) args-env))
+
+(define (hoist:struct-ref exp member)
+ (define-values (exp* exp-env) (hoist-functions exp))
+ (values `(struct-ref ,exp* ,member) exp-env))
(define (hoist:top-level inputs body)
(define-values (body* body-env)
@@ -744,6 +766,8 @@
(hoist:primcall operator args))
(('call args ...)
(hoist:call args))
+ (('struct-ref exp member)
+ (hoist:struct-ref exp member))
(('outputs (names exps) ...)
(hoist:outputs names exps))
(('top-level inputs body)
@@ -794,27 +818,36 @@
(match type
(('primitive name) name)))
-(define int-type (primitive-type 'int))
-(define float-type (primitive-type 'float))
-(define bool-type (primitive-type 'bool))
-(define vec2-type (primitive-type 'vec2))
-(define vec3-type (primitive-type 'vec3))
-(define vec4-type (primitive-type 'vec4))
-(define mat3-type (primitive-type 'mat3))
-(define mat4-type (primitive-type 'mat4))
-(define sampler-2d-type (primitive-type 'sampler-2d))
+;; Outputs type:
+(define outputs-type '(outputs))
-(define (type-name->type name)
- (case name
- ((bool) bool-type)
- ((int) int-type)
- ((float) float-type)
- ((vec2) vec2-type)
- ((vec3) vec3-type)
- ((vec4) vec4-type)
- ((mat3) mat3-type)
- ((mat4) mat4-type)
- ((sampler-2d) sampler-2d-type)))
+(define (outputs-type? obj)
+ (match obj
+ (('outputs) #t)
+ (_ #f)))
+
+;; Struct type:
+(define (struct-type name members)
+ `(struct ,name ,members))
+
+(define (struct-type? obj)
+ (match obj
+ (('struct _ _) #t)
+ (_ #f)))
+
+(define (struct-type-name type)
+ (match type
+ (('struct name _) name)))
+
+(define (struct-type-members type)
+ (match type
+ (('struct _ members) members)))
+
+(define (struct-type-ref type member)
+ (assq-ref (struct-type-members type) member))
+
+(define-syntax-rule (define-struct-type (var-name name) (types names) ...)
+ (define var-name (struct-type 'name (list (cons 'names types) ...))))
;; Type variables:
(define unique-type-variable-counter (make-parameter 0))
@@ -860,43 +893,50 @@
(('-> _ returns) returns)))
;; For all types:
-(define (for-all-type quantifiers type predicate)
- `(for-all ,quantifiers ,type ,predicate))
+(define (for-all-type quantifiers type)
+ `(for-all ,quantifiers ,type))
(define (for-all-type? obj)
(match obj
- (('for-all _ _ _) #t)
+ (('for-all _ _) #t)
(_ #f)))
(define (for-all-type-quantifiers type)
(match type
- (('for-all q _ _) q)))
+ (('for-all q _) q)))
(define (for-all-type-ref type)
(match type
- (('for-all _ t _) t)))
-
-(define (for-all-type-predicate type)
- (match type
- (('for-all _ _ p) p)))
+ (('for-all _ t) t)))
-;; Outputs type:
-(define outputs-type '(outputs))
+;; Qualified types:
+(define (qualified-type type pred)
+ `(qualified ,type ,pred))
-(define (outputs-type? obj)
+(define (qualified-type? obj)
(match obj
- (('outputs) #t)
+ (('qualified _ _) #t)
(_ #f)))
+(define (qualified-type-ref type)
+ (match type
+ (('qualified type _) type)))
+
+(define (qualified-type-predicate type)
+ (match type
+ (('qualified _ pred) pred)))
+
(define (type? obj)
(or (primitive-type? obj)
(type-variable? obj)
(function-type? obj)
+ (struct-type? obj)
(outputs-type? obj)))
(define (apply-substitution-to-type type from to)
(cond
((or (primitive-type? type)
+ (struct-type? type)
(outputs-type? type))
type)
((type-variable? type)
@@ -911,6 +951,11 @@
(function-type-returns type))))
((for-all-type? type)
type)
+ ;; ((qualified-type? type)
+ ;; (qualified-type (apply-substitution-to-type
+ ;; (qualified-type-ref type) from to)
+ ;; (apply-substitution-to-predicate
+ ;; (qualified-type-predicate type) from to)))
(else (error "invalid type" type))))
(define (apply-substitutions-to-type type subs)
@@ -1072,6 +1117,9 @@
(predicate:substitute from to)))
subs)))
+(define (predicate:struct-has? struct member var)
+ `(struct-has? ,struct ,member ,var))
+
(define (compose-predicates a b)
(cond
((and (eq? a #t) (eq? b #t))
@@ -1110,7 +1158,11 @@
preds)))
(('substitute a b)
`(substitute ,(apply-substitution-to-type a from to)
- ,(apply-substitution-to-type b from to)))))
+ ,(apply-substitution-to-type b from to)))
+ (('struct-has? struct member var)
+ `(struct-has? ,(apply-substitution-to-type struct from to)
+ ,member
+ ,var))))
(define (apply-substitutions-to-predicate pred subs)
(env-fold (lambda (from to pred*)
@@ -1271,7 +1323,11 @@
;; Substitution always succeeds and returns a substitution to be
;; carried forward in the inference process.
(('substitute a b)
- (values #t (list (cons a b))))))
+ (values #t (list (cons a b))))
+ (('struct-has? struct member var)
+ (if (struct-type? struct)
+ (values #t (list (cons var (struct-type-ref struct member))))
+ (values pred '())))))
(define (eval-predicate* pred subs)
(define-values (new-pred pred-subs)
@@ -1279,7 +1335,7 @@
(apply-substitutions-to-predicate pred subs)))
;; TODO: Get information about *why* the predicate failed.
(unless new-pred
- (error "predicate failure"))
+ (error "predicate failure" pred))
;; Recursively evaluate the predicate, applying the substitutions
;; generated by the last evaluation, until it cannot be simplified
;; any further.
@@ -1289,7 +1345,9 @@
(define* (free-variables-in-type type)
(cond
- ((primitive-type? type) '())
+ ((or (primitive-type? type)
+ (struct-type? type))
+ '())
((type-variable? type) (list type))
((function-type? type)
(let ((params (function-type-parameters type)))
@@ -1349,7 +1407,10 @@
preds))
(('substitute a b)
(append (free-variables-in-type a)
- (free-variables-in-type b)))))
+ (free-variables-in-type b)))
+ (('struct-has? struct member var)
+ (append (free-variables-in-type struct)
+ (free-variables-in-type var)))))
;; Quantified variables:
;; - Unused parameters
@@ -1363,7 +1424,7 @@
(free-variables-in-env env))
(() type)
((quantifiers ...)
- (for-all-type quantifiers type pred)))
+ (for-all-type quantifiers (qualified-type type pred))))
type))
(define (instantiate for-all)
@@ -1372,9 +1433,16 @@
(extend-env var (fresh-type-variable) env))
(empty-env)
(for-all-type-quantifiers for-all)))
+ (define type (for-all-type-ref for-all))
(values
- (apply-substitutions-to-type (for-all-type-ref for-all) subs)
- (apply-substitutions-to-predicate (for-all-type-predicate for-all) subs)))
+ (apply-substitutions-to-type (if (qualified-type? type)
+ (qualified-type-ref type)
+ type)
+ subs)
+ (if (qualified-type? type)
+ (apply-substitutions-to-predicate (qualified-type-predicate type)
+ subs)
+ #t)))
(define (maybe-instantiate types)
(define types+preds
@@ -1391,6 +1459,11 @@
'()
(error "primitive type mismatch" a b)))
+(define (unify:structs a b)
+ (if (equal? a b)
+ '()
+ (error "struct type mismatch" a b)))
+
(define (unify:variable a b)
(cond
((eq? a b)
@@ -1422,6 +1495,8 @@
(match (list a b)
(((? primitive-type? a) (? primitive-type? b))
(unify:primitives a b))
+ (((? struct-type? a) (? struct-type? b))
+ (unify:structs a b))
((or ((? type-variable? a) b)
(b (? type-variable? a)))
(unify:variable a b))
@@ -1429,6 +1504,8 @@
(unify:functions a b))
(((? outputs-type?) (? outputs-type?))
'())
+ (((? type?) (? type?))
+ (error "type mismatch" a b))
((() ())
'())
(((a rest-a ...) (b rest-b ...))
@@ -1527,7 +1604,7 @@
(eval-predicate* body-pred body-subs))
(values (texp (list (generalize
(function-type (apply-substitutions-to-types param-types
- body-subs)
+ subs)
(texp-types body*))
pred env))
`(lambda ,params ,body*))
@@ -1604,6 +1681,17 @@
combined-subs
pred))
+(define (infer:struct-ref exp member env)
+ (define-values (exp* exp-subs exp-pred)
+ (infer-exp exp env))
+ (define exp-type (single-type exp*))
+ (define tvar (fresh-type-variable))
+ (values (texp (list tvar)
+ `(struct-ref ,exp* ,member))
+ exp-subs
+ (compose-predicates exp-pred
+ (predicate:struct-has? exp-type member tvar))))
+
(define (infer:let names exps body env)
(define-values (exps* exp-subs exp-pred)
(infer:list exps env))
@@ -1643,31 +1731,37 @@
;; Eval predicate.
(define-values (pred combined-subs)
(eval-predicate* exp-pred (compose-substitutions exp-subs unify-subs)))
- (values (texp (map single-type exps*)
+ (values (texp (list outputs-type)
`(outputs ,@(map list names exps*)))
combined-subs
pred))
(define (infer:top-level bindings body env)
- (define (infer-bindings bindings texps subs pred)
+ (define (infer-bindings bindings texps subs pred env)
(match bindings
(()
- (values (reverse texps) subs pred))
- ((('function _ exp) . rest)
+ (values (reverse texps) subs pred env))
+ ((('function name exp) . rest)
(define-values (texp subs* pred*)
(infer-exp exp env))
(define-values (new-pred combined-subs)
(eval-predicate* (compose-predicates pred pred*)
(compose-substitutions subs subs*)))
+ (define env*
+ (apply-substitutions-to-env (extend-env name (texp-types texp) env)
+ combined-subs))
(infer-bindings rest
(cons texp texps)
combined-subs
- new-pred))
- (((_ type-name _) . rest)
+ new-pred
+ env*))
+ (((_ type-name name) . rest)
+ (define types (list (type-name->type type-name)))
(infer-bindings rest
- (cons (list (type-name->type type-name)) texps)
+ (cons types texps)
subs
- pred))))
+ pred
+ (extend-env name types env)))))
(define qualifiers (map first bindings))
(define names
(map (match-lambda
@@ -1679,19 +1773,8 @@
(((? top-level-qualifier?) type-name _) type-name)
(_ #f))
bindings))
- (define-values (exps exp-subs exp-pred)
- (infer-bindings bindings '() '() #t))
- (define exp-types
- (map (lambda (x)
- (if (texp? x)
- (texp-types x)
- x))
- exps))
- (define env*
- (fold extend-env
- (apply-substitutions-to-env env exp-subs)
- names
- exp-types))
+ (define-values (exps exp-subs exp-pred env*)
+ (infer-bindings bindings '() '() #t env))
(define-values (body* body-subs body-pred)
(infer-exp body env*))
(define-values (pred combined-subs)
@@ -1731,25 +1814,62 @@
(infer:primitive-call operator args env))
(('call operator args ...)
(infer:call operator args env))
+ (('struct-ref exp member)
+ (infer:struct-ref exp member env))
(('outputs (names exps) ...)
(infer:outputs names exps env))
(('top-level bindings body)
(infer:top-level bindings body env))
(_ (error "unknown form" exp))))
+;; Built-in types:
+(define int-type (primitive-type 'int))
+(define float-type (primitive-type 'float))
+(define bool-type (primitive-type 'bool))
+(define-struct-type (vec2-type vec2)
+ (float-type x)
+ (float-type y))
+(define-struct-type (vec3-type vec3)
+ (float-type x)
+ (float-type y)
+ (float-type z))
+(define-struct-type (vec4-type vec4)
+ (float-type x)
+ (float-type y)
+ (float-type z)
+ (float-type w))
+;; TODO: Matrices are technically array types in GLSL, but we are
+;; choosing to represent them opaquely for now to keep things simple.
+(define mat3-type (primitive-type 'mat3))
+(define mat4-type (primitive-type 'mat4))
+(define sampler-2d-type (primitive-type 'sampler-2d))
+
+(define (type-name->type name)
+ (case name
+ ((bool) bool-type)
+ ((int) int-type)
+ ((float) float-type)
+ ((vec2) vec2-type)
+ ((vec3) vec3-type)
+ ((vec4) vec4-type)
+ ((mat3) mat3-type)
+ ((mat4) mat4-type)
+ ((sampler-2d) sampler-2d-type)))
+
(define add/sub-type
(let ((a (fresh-type-variable)))
(list (for-all-type
(list a)
- (function-type (list a a) (list a))
- (predicate:or
- (predicate:= a int-type)
- (predicate:= a float-type)
- (predicate:= a vec2-type)
- (predicate:= a vec3-type)
- (predicate:= a vec4-type)
- (predicate:= a mat3-type)
- (predicate:= a mat4-type))))))
+ (qualified-type
+ (function-type (list a a) (list a))
+ (predicate:or
+ (predicate:= a int-type)
+ (predicate:= a float-type)
+ (predicate:= a vec2-type)
+ (predicate:= a vec3-type)
+ (predicate:= a vec4-type)
+ (predicate:= a mat3-type)
+ (predicate:= a mat4-type)))))))
(define-syntax-rule (a+b->c (ta tb tc) ...)
(let ((a (fresh-type-variable))
@@ -1757,12 +1877,13 @@
(c (fresh-type-variable)))
(list (for-all-type
(list a b c)
- (function-type (list a b) (list c))
- (predicate:or
- (predicate:and (predicate:= a ta)
- (predicate:= b tb)
- (predicate:substitute c tc))
- ...)))))
+ (qualified-type
+ (function-type (list a b) (list c))
+ (predicate:or
+ (predicate:and (predicate:= a ta)
+ (predicate:= b tb)
+ (predicate:substitute c tc))
+ ...))))))
(define mul-type
(a+b->c (int-type int-type int-type)
@@ -1777,11 +1898,11 @@
(vec4-type float-type vec4-type)
(float-type vec4-type vec4-type)
(mat3-type mat3-type mat3-type)
- (mat3-type float-type mat3-type)
- (float-type mat3-type mat3-type)
+ (mat3-type vec3-type mat3-type)
+ (vec3-type mat3-type mat3-type)
(mat4-type mat4-type mat4-type)
- (mat4-type float-type mat4-type)
- (float-type mat4-type mat4-type)))
+ (mat4-type vec4-type vec4-type)
+ (vec4-type mat4-type vec4-type)))
(define div-type
(a+b->c (int-type int-type int-type)
@@ -1805,10 +1926,11 @@
(let ((a (fresh-type-variable)))
(list (for-all-type
(list a)
- (function-type (list a a) (list bool-type))
- (predicate:or
- (predicate:= a int-type)
- (predicate:= a float-type))))))
+ (qualified-type
+ (function-type (list a a) (list bool-type))
+ (predicate:or
+ (predicate:= a int-type)
+ (predicate:= a float-type)))))))
(define not-type
(list (function-type (list bool-type) (list bool-type))))
@@ -1829,28 +1951,31 @@
(let ((a (fresh-type-variable)))
(list (for-all-type
(list a)
- (function-type (list a) (list a))
- (predicate:or
- (predicate:= a int-type)
- (predicate:= a float-type))))))
+ (qualified-type
+ (function-type (list a) (list a))
+ (predicate:or
+ (predicate:= a int-type)
+ (predicate:= a float-type)))))))
(define sqrt-type
(let ((a (fresh-type-variable)))
(list (for-all-type
(list a)
- (function-type (list a) (list a))
- (predicate:or
- (predicate:= a int-type)
- (predicate:= a float-type))))))
+ (qualified-type
+ (function-type (list a) (list a))
+ (predicate:or
+ (predicate:= a int-type)
+ (predicate:= a float-type)))))))
(define min/max-type
(let ((a (fresh-type-variable)))
(list (for-all-type
(list a)
- (function-type (list a a) (list a))
- (predicate:or
- (predicate:= a int-type)
- (predicate:= a float-type))))))
+ (qualified-type
+ (function-type (list a a) (list a))
+ (predicate:or
+ (predicate:= a int-type)
+ (predicate:= a float-type)))))))
(define trigonometry-type
(list (function-type (list float-type) (list float-type))))
@@ -1859,10 +1984,11 @@
(let ((a (fresh-type-variable)))
(list (for-all-type
(list a)
- (function-type (list a a) (list a))
- (predicate:or
- (predicate:= a int-type)
- (predicate:= a float-type))))))
+ (qualified-type
+ (function-type (list a a) (list a))
+ (predicate:or
+ (predicate:= a int-type)
+ (predicate:= a float-type)))))))
(define texture-2d-ref-type
(list (function-type (list sampler-2d-type vec2-type)
@@ -1928,17 +2054,28 @@
;; Compute all the valid permutations of substitutions that a
;; predicate could produce.
-(define (possible-substitutions pred)
+(define (possible-substitutions pred structs)
(match pred
(#t '())
(('substitute from to)
(list (list (cons from to))))
(('= a b)
(list (list (cons a b))))
+ (('struct-has? struct-var member member-var)
+ (filter-map (lambda (struct)
+ (let ((member-type (struct-type-ref struct member)))
+ (and member-type
+ (list (cons struct-var struct)
+ (cons member-var member-type)))))
+ structs))
(('or preds ...)
- (concatenate (map possible-substitutions preds)))
+ (concatenate (map (lambda (pred)
+ (possible-substitutions pred structs))
+ preds)))
(((or 'and 'list) preds ...)
- (let loop ((in (map possible-substitutions preds)))
+ (let loop ((in (map (lambda (pred)
+ (possible-substitutions pred structs))
+ preds)))
(match in
(() '())
((options . rest)
@@ -1952,7 +2089,19 @@
rest-options)))
options)))))))
+(define (find-structs exp)
+ (match exp
+ ((? struct-type?)
+ (list exp))
+ ((exps ...)
+ (append-map find-structs exps))
+ (_ '())))
+
(define (resolve-overloads program)
+ ;; Find all of the struct types used in the program. They will be
+ ;; used to generate overloaded functions that take one or more
+ ;; structs as arguments.
+ (define structs (delete-duplicates (find-structs program)))
(match program
(('t types ('top-level bindings body))
(define bindings*
@@ -1960,7 +2109,8 @@
(match bindings
(() '())
((('function name ('t ((? for-all-type? type)) func)) . rest)
- (define func-type (for-all-type-ref type))
+ (define qtype (for-all-type-ref type))
+ (define func-type (qualified-type-ref qtype))
(append (map (lambda (subs)
(define type*
(apply-substitutions-to-type func-type subs))
@@ -1968,7 +2118,9 @@
(apply-substitutions-to-exp func subs))
`(function ,name (t (,type*) ,func*)))
(delete-duplicates
- (possible-substitutions (for-all-type-predicate type))))
+ (possible-substitutions
+ (qualified-type-predicate qtype)
+ structs)))
(loop rest)))
((binding . rest)
(cons binding (loop rest))))))
@@ -1983,7 +2135,9 @@
;; Transform a fully typed Seagull program into a string of GLSL code.
(define (type-name type)
- (primitive-type-name type))
+ (if (struct-type? type)
+ (struct-type-name type)
+ (primitive-type-name type)))
(define (single-temp temps)
(match temps
@@ -2157,6 +2311,19 @@
", "))
output-temps)
+(define (emit:struct-ref type exp member version port level)
+ (define input-temp
+ (match (emit-glsl exp version port level)
+ ((temp) temp)))
+ (define output-temp (unique-identifier))
+ (indent level port)
+ (format port "~a ~a = ~a.~a;\n"
+ (type-name type)
+ output-temp
+ input-temp
+ member)
+ (list output-temp))
+
(define %type-name-map
'((sampler-2d . sampler2D)))
@@ -2215,6 +2382,8 @@
(emit:primcall type op args version port level))
(('t types ('call operator args ...))
(emit:call types operator args version port level))
+ (('t (type) ('struct-ref exp member))
+ (emit:struct-ref type exp member version port level))
(('t _ ('outputs (names exps) ...))
(emit:outputs names exps version port level))
(('t _ ('top-level (bindings ...) body))