From 19ee392840e3e5d840a03924d89f410b039fdf50 Mon Sep 17 00:00:00 2001 From: David Thompson Date: Mon, 6 Feb 2023 08:10:33 -0500 Subject: Add struct types. --- chickadee/graphics/seagull.scm | 397 +++++++++++++++++++++++++++++------------ 1 file changed, 283 insertions(+), 114 deletions(-) diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm index bb7008f..ca11ec3 100644 --- a/chickadee/graphics/seagull.scm +++ b/chickadee/graphics/seagull.scm @@ -292,6 +292,18 @@ (define (expand:values exps stage env) `(values ,@(expand:list exps stage env))) +(define (expand:-> exp members stage env) + (define exp* (expand exp stage env)) + (match members + ((member . rest) + (let loop ((members rest) + (prev `(struct-ref ,exp* ,member))) + (match members + (() prev) + ((next . rest) + (loop `(struct-ref ,prev ,next) + rest))))))) + ;; Arithmetic operators, in true Lisp fashion, can accept many ;; arguments. + and * accept 0 or more. - and / accept one or more. ;; The expansion pass transforms all such expressions into binary @@ -437,6 +449,8 @@ body) (expand:top-level qualifiers types names body stage env)) ;; Macros: + (('-> exp (? symbol? members) ..1) + (expand:-> exp members stage env)) (('let* (bindings ...) body) (expand:let* bindings body stage env)) (('+ args ...) @@ -526,6 +540,9 @@ (propagate-constants arg env)) args))) +(define (propagate:struct-ref exp member env) + `(struct-ref ,(propagate-constants exp env) ,member)) + ;; The division of two integers can result in a rational, non-integer, ;; such as 1/2. This isn't how integer division works in GLSL, so we ;; need to round the result to an integer. @@ -576,6 +593,8 @@ (propagate:primcall operator args env)) (('call operator args ...) (propagate:call operator args env)) + (('struct-ref exp member) + (propagate:struct-ref exp member env)) (('outputs (names exps) ...) (propagate:outputs names exps env)) (('top-level inputs body) @@ -637,6 +656,8 @@ ((or ('primcall _ args ...) ('call args ...)) (check-free-variables-in-list args bound-vars top-level-vars)) + (('struct-ref exp member) + (check-free-variables exp bound-vars top-level-vars)) (('outputs (names exps) ...) (check-free-variables-in-list exps bound-vars top-level-vars)) (('top-level ((_ _ names) ...) body) @@ -695,8 +716,7 @@ (define (hoist:lambda params body) (define-values (body* body-env) (hoist-functions body)) - (values `(lambda ,params ,body*) - body-env)) + (values `(lambda ,params ,body*) body-env)) (define (hoist:values exps) (define-values (exps* exp-env) @@ -705,13 +725,15 @@ (define (hoist:primcall operator args) (define-values (args* args-env) (hoist:list args)) - (values `(primcall ,operator ,@args*) - args-env)) + (values `(primcall ,operator ,@args*) args-env)) (define (hoist:call args) (define-values (args* args-env) (hoist:list args)) - (values `(call ,@args*) - args-env)) + (values `(call ,@args*) args-env)) + +(define (hoist:struct-ref exp member) + (define-values (exp* exp-env) (hoist-functions exp)) + (values `(struct-ref ,exp* ,member) exp-env)) (define (hoist:top-level inputs body) (define-values (body* body-env) @@ -744,6 +766,8 @@ (hoist:primcall operator args)) (('call args ...) (hoist:call args)) + (('struct-ref exp member) + (hoist:struct-ref exp member)) (('outputs (names exps) ...) (hoist:outputs names exps)) (('top-level inputs body) @@ -794,27 +818,36 @@ (match type (('primitive name) name))) -(define int-type (primitive-type 'int)) -(define float-type (primitive-type 'float)) -(define bool-type (primitive-type 'bool)) -(define vec2-type (primitive-type 'vec2)) -(define vec3-type (primitive-type 'vec3)) -(define vec4-type (primitive-type 'vec4)) -(define mat3-type (primitive-type 'mat3)) -(define mat4-type (primitive-type 'mat4)) -(define sampler-2d-type (primitive-type 'sampler-2d)) +;; Outputs type: +(define outputs-type '(outputs)) -(define (type-name->type name) - (case name - ((bool) bool-type) - ((int) int-type) - ((float) float-type) - ((vec2) vec2-type) - ((vec3) vec3-type) - ((vec4) vec4-type) - ((mat3) mat3-type) - ((mat4) mat4-type) - ((sampler-2d) sampler-2d-type))) +(define (outputs-type? obj) + (match obj + (('outputs) #t) + (_ #f))) + +;; Struct type: +(define (struct-type name members) + `(struct ,name ,members)) + +(define (struct-type? obj) + (match obj + (('struct _ _) #t) + (_ #f))) + +(define (struct-type-name type) + (match type + (('struct name _) name))) + +(define (struct-type-members type) + (match type + (('struct _ members) members))) + +(define (struct-type-ref type member) + (assq-ref (struct-type-members type) member)) + +(define-syntax-rule (define-struct-type (var-name name) (types names) ...) + (define var-name (struct-type 'name (list (cons 'names types) ...)))) ;; Type variables: (define unique-type-variable-counter (make-parameter 0)) @@ -860,43 +893,50 @@ (('-> _ returns) returns))) ;; For all types: -(define (for-all-type quantifiers type predicate) - `(for-all ,quantifiers ,type ,predicate)) +(define (for-all-type quantifiers type) + `(for-all ,quantifiers ,type)) (define (for-all-type? obj) (match obj - (('for-all _ _ _) #t) + (('for-all _ _) #t) (_ #f))) (define (for-all-type-quantifiers type) (match type - (('for-all q _ _) q))) + (('for-all q _) q))) (define (for-all-type-ref type) (match type - (('for-all _ t _) t))) - -(define (for-all-type-predicate type) - (match type - (('for-all _ _ p) p))) + (('for-all _ t) t))) -;; Outputs type: -(define outputs-type '(outputs)) +;; Qualified types: +(define (qualified-type type pred) + `(qualified ,type ,pred)) -(define (outputs-type? obj) +(define (qualified-type? obj) (match obj - (('outputs) #t) + (('qualified _ _) #t) (_ #f))) +(define (qualified-type-ref type) + (match type + (('qualified type _) type))) + +(define (qualified-type-predicate type) + (match type + (('qualified _ pred) pred))) + (define (type? obj) (or (primitive-type? obj) (type-variable? obj) (function-type? obj) + (struct-type? obj) (outputs-type? obj))) (define (apply-substitution-to-type type from to) (cond ((or (primitive-type? type) + (struct-type? type) (outputs-type? type)) type) ((type-variable? type) @@ -911,6 +951,11 @@ (function-type-returns type)))) ((for-all-type? type) type) + ;; ((qualified-type? type) + ;; (qualified-type (apply-substitution-to-type + ;; (qualified-type-ref type) from to) + ;; (apply-substitution-to-predicate + ;; (qualified-type-predicate type) from to))) (else (error "invalid type" type)))) (define (apply-substitutions-to-type type subs) @@ -1072,6 +1117,9 @@ (predicate:substitute from to))) subs))) +(define (predicate:struct-has? struct member var) + `(struct-has? ,struct ,member ,var)) + (define (compose-predicates a b) (cond ((and (eq? a #t) (eq? b #t)) @@ -1110,7 +1158,11 @@ preds))) (('substitute a b) `(substitute ,(apply-substitution-to-type a from to) - ,(apply-substitution-to-type b from to))))) + ,(apply-substitution-to-type b from to))) + (('struct-has? struct member var) + `(struct-has? ,(apply-substitution-to-type struct from to) + ,member + ,var)))) (define (apply-substitutions-to-predicate pred subs) (env-fold (lambda (from to pred*) @@ -1271,7 +1323,11 @@ ;; Substitution always succeeds and returns a substitution to be ;; carried forward in the inference process. (('substitute a b) - (values #t (list (cons a b)))))) + (values #t (list (cons a b)))) + (('struct-has? struct member var) + (if (struct-type? struct) + (values #t (list (cons var (struct-type-ref struct member)))) + (values pred '()))))) (define (eval-predicate* pred subs) (define-values (new-pred pred-subs) @@ -1279,7 +1335,7 @@ (apply-substitutions-to-predicate pred subs))) ;; TODO: Get information about *why* the predicate failed. (unless new-pred - (error "predicate failure")) + (error "predicate failure" pred)) ;; Recursively evaluate the predicate, applying the substitutions ;; generated by the last evaluation, until it cannot be simplified ;; any further. @@ -1289,7 +1345,9 @@ (define* (free-variables-in-type type) (cond - ((primitive-type? type) '()) + ((or (primitive-type? type) + (struct-type? type)) + '()) ((type-variable? type) (list type)) ((function-type? type) (let ((params (function-type-parameters type))) @@ -1349,7 +1407,10 @@ preds)) (('substitute a b) (append (free-variables-in-type a) - (free-variables-in-type b))))) + (free-variables-in-type b))) + (('struct-has? struct member var) + (append (free-variables-in-type struct) + (free-variables-in-type var))))) ;; Quantified variables: ;; - Unused parameters @@ -1363,7 +1424,7 @@ (free-variables-in-env env)) (() type) ((quantifiers ...) - (for-all-type quantifiers type pred))) + (for-all-type quantifiers (qualified-type type pred)))) type)) (define (instantiate for-all) @@ -1372,9 +1433,16 @@ (extend-env var (fresh-type-variable) env)) (empty-env) (for-all-type-quantifiers for-all))) + (define type (for-all-type-ref for-all)) (values - (apply-substitutions-to-type (for-all-type-ref for-all) subs) - (apply-substitutions-to-predicate (for-all-type-predicate for-all) subs))) + (apply-substitutions-to-type (if (qualified-type? type) + (qualified-type-ref type) + type) + subs) + (if (qualified-type? type) + (apply-substitutions-to-predicate (qualified-type-predicate type) + subs) + #t))) (define (maybe-instantiate types) (define types+preds @@ -1391,6 +1459,11 @@ '() (error "primitive type mismatch" a b))) +(define (unify:structs a b) + (if (equal? a b) + '() + (error "struct type mismatch" a b))) + (define (unify:variable a b) (cond ((eq? a b) @@ -1422,6 +1495,8 @@ (match (list a b) (((? primitive-type? a) (? primitive-type? b)) (unify:primitives a b)) + (((? struct-type? a) (? struct-type? b)) + (unify:structs a b)) ((or ((? type-variable? a) b) (b (? type-variable? a))) (unify:variable a b)) @@ -1429,6 +1504,8 @@ (unify:functions a b)) (((? outputs-type?) (? outputs-type?)) '()) + (((? type?) (? type?)) + (error "type mismatch" a b)) ((() ()) '()) (((a rest-a ...) (b rest-b ...)) @@ -1527,7 +1604,7 @@ (eval-predicate* body-pred body-subs)) (values (texp (list (generalize (function-type (apply-substitutions-to-types param-types - body-subs) + subs) (texp-types body*)) pred env)) `(lambda ,params ,body*)) @@ -1604,6 +1681,17 @@ combined-subs pred)) +(define (infer:struct-ref exp member env) + (define-values (exp* exp-subs exp-pred) + (infer-exp exp env)) + (define exp-type (single-type exp*)) + (define tvar (fresh-type-variable)) + (values (texp (list tvar) + `(struct-ref ,exp* ,member)) + exp-subs + (compose-predicates exp-pred + (predicate:struct-has? exp-type member tvar)))) + (define (infer:let names exps body env) (define-values (exps* exp-subs exp-pred) (infer:list exps env)) @@ -1643,31 +1731,37 @@ ;; Eval predicate. (define-values (pred combined-subs) (eval-predicate* exp-pred (compose-substitutions exp-subs unify-subs))) - (values (texp (map single-type exps*) + (values (texp (list outputs-type) `(outputs ,@(map list names exps*))) combined-subs pred)) (define (infer:top-level bindings body env) - (define (infer-bindings bindings texps subs pred) + (define (infer-bindings bindings texps subs pred env) (match bindings (() - (values (reverse texps) subs pred)) - ((('function _ exp) . rest) + (values (reverse texps) subs pred env)) + ((('function name exp) . rest) (define-values (texp subs* pred*) (infer-exp exp env)) (define-values (new-pred combined-subs) (eval-predicate* (compose-predicates pred pred*) (compose-substitutions subs subs*))) + (define env* + (apply-substitutions-to-env (extend-env name (texp-types texp) env) + combined-subs)) (infer-bindings rest (cons texp texps) combined-subs - new-pred)) - (((_ type-name _) . rest) + new-pred + env*)) + (((_ type-name name) . rest) + (define types (list (type-name->type type-name))) (infer-bindings rest - (cons (list (type-name->type type-name)) texps) + (cons types texps) subs - pred)))) + pred + (extend-env name types env))))) (define qualifiers (map first bindings)) (define names (map (match-lambda @@ -1679,19 +1773,8 @@ (((? top-level-qualifier?) type-name _) type-name) (_ #f)) bindings)) - (define-values (exps exp-subs exp-pred) - (infer-bindings bindings '() '() #t)) - (define exp-types - (map (lambda (x) - (if (texp? x) - (texp-types x) - x)) - exps)) - (define env* - (fold extend-env - (apply-substitutions-to-env env exp-subs) - names - exp-types)) + (define-values (exps exp-subs exp-pred env*) + (infer-bindings bindings '() '() #t env)) (define-values (body* body-subs body-pred) (infer-exp body env*)) (define-values (pred combined-subs) @@ -1731,25 +1814,62 @@ (infer:primitive-call operator args env)) (('call operator args ...) (infer:call operator args env)) + (('struct-ref exp member) + (infer:struct-ref exp member env)) (('outputs (names exps) ...) (infer:outputs names exps env)) (('top-level bindings body) (infer:top-level bindings body env)) (_ (error "unknown form" exp)))) +;; Built-in types: +(define int-type (primitive-type 'int)) +(define float-type (primitive-type 'float)) +(define bool-type (primitive-type 'bool)) +(define-struct-type (vec2-type vec2) + (float-type x) + (float-type y)) +(define-struct-type (vec3-type vec3) + (float-type x) + (float-type y) + (float-type z)) +(define-struct-type (vec4-type vec4) + (float-type x) + (float-type y) + (float-type z) + (float-type w)) +;; TODO: Matrices are technically array types in GLSL, but we are +;; choosing to represent them opaquely for now to keep things simple. +(define mat3-type (primitive-type 'mat3)) +(define mat4-type (primitive-type 'mat4)) +(define sampler-2d-type (primitive-type 'sampler-2d)) + +(define (type-name->type name) + (case name + ((bool) bool-type) + ((int) int-type) + ((float) float-type) + ((vec2) vec2-type) + ((vec3) vec3-type) + ((vec4) vec4-type) + ((mat3) mat3-type) + ((mat4) mat4-type) + ((sampler-2d) sampler-2d-type))) + (define add/sub-type (let ((a (fresh-type-variable))) (list (for-all-type (list a) - (function-type (list a a) (list a)) - (predicate:or - (predicate:= a int-type) - (predicate:= a float-type) - (predicate:= a vec2-type) - (predicate:= a vec3-type) - (predicate:= a vec4-type) - (predicate:= a mat3-type) - (predicate:= a mat4-type)))))) + (qualified-type + (function-type (list a a) (list a)) + (predicate:or + (predicate:= a int-type) + (predicate:= a float-type) + (predicate:= a vec2-type) + (predicate:= a vec3-type) + (predicate:= a vec4-type) + (predicate:= a mat3-type) + (predicate:= a mat4-type))))))) (define-syntax-rule (a+b->c (ta tb tc) ...) (let ((a (fresh-type-variable)) @@ -1757,12 +1877,13 @@ (c (fresh-type-variable))) (list (for-all-type (list a b c) - (function-type (list a b) (list c)) - (predicate:or - (predicate:and (predicate:= a ta) - (predicate:= b tb) - (predicate:substitute c tc)) - ...))))) + (qualified-type + (function-type (list a b) (list c)) + (predicate:or + (predicate:and (predicate:= a ta) + (predicate:= b tb) + (predicate:substitute c tc)) + ...)))))) (define mul-type (a+b->c (int-type int-type int-type) @@ -1777,11 +1898,11 @@ (vec4-type float-type vec4-type) (float-type vec4-type vec4-type) (mat3-type mat3-type mat3-type) - (mat3-type float-type mat3-type) - (float-type mat3-type mat3-type) + (mat3-type vec3-type mat3-type) + (vec3-type mat3-type mat3-type) (mat4-type mat4-type mat4-type) - (mat4-type float-type mat4-type) - (float-type mat4-type mat4-type))) + (mat4-type vec4-type vec4-type) + (vec4-type mat4-type vec4-type))) (define div-type (a+b->c (int-type int-type int-type) @@ -1805,10 +1926,11 @@ (let ((a (fresh-type-variable))) (list (for-all-type (list a) - (function-type (list a a) (list bool-type)) - (predicate:or - (predicate:= a int-type) - (predicate:= a float-type)))))) + (qualified-type + (function-type (list a a) (list bool-type)) + (predicate:or + (predicate:= a int-type) + (predicate:= a float-type))))))) (define not-type (list (function-type (list bool-type) (list bool-type)))) @@ -1829,28 +1951,31 @@ (let ((a (fresh-type-variable))) (list (for-all-type (list a) - (function-type (list a) (list a)) - (predicate:or - (predicate:= a int-type) - (predicate:= a float-type)))))) + (qualified-type + (function-type (list a) (list a)) + (predicate:or + (predicate:= a int-type) + (predicate:= a float-type))))))) (define sqrt-type (let ((a (fresh-type-variable))) (list (for-all-type (list a) - (function-type (list a) (list a)) - (predicate:or - (predicate:= a int-type) - (predicate:= a float-type)))))) + (qualified-type + (function-type (list a) (list a)) + (predicate:or + (predicate:= a int-type) + (predicate:= a float-type))))))) (define min/max-type (let ((a (fresh-type-variable))) (list (for-all-type (list a) - (function-type (list a a) (list a)) - (predicate:or - (predicate:= a int-type) - (predicate:= a float-type)))))) + (qualified-type + (function-type (list a a) (list a)) + (predicate:or + (predicate:= a int-type) + (predicate:= a float-type))))))) (define trigonometry-type (list (function-type (list float-type) (list float-type)))) @@ -1859,10 +1984,11 @@ (let ((a (fresh-type-variable))) (list (for-all-type (list a) - (function-type (list a a) (list a)) - (predicate:or - (predicate:= a int-type) - (predicate:= a float-type)))))) + (qualified-type + (function-type (list a a) (list a)) + (predicate:or + (predicate:= a int-type) + (predicate:= a float-type))))))) (define texture-2d-ref-type (list (function-type (list sampler-2d-type vec2-type) @@ -1928,17 +2054,28 @@ ;; Compute all the valid permutations of substitutions that a ;; predicate could produce. -(define (possible-substitutions pred) +(define (possible-substitutions pred structs) (match pred (#t '()) (('substitute from to) (list (list (cons from to)))) (('= a b) (list (list (cons a b)))) + (('struct-has? struct-var member member-var) + (filter-map (lambda (struct) + (let ((member-type (struct-type-ref struct member))) + (and member-type + (list (cons struct-var struct) + (cons member-var member-type))))) + structs)) (('or preds ...) - (concatenate (map possible-substitutions preds))) + (concatenate (map (lambda (pred) + (possible-substitutions pred structs)) + preds))) (((or 'and 'list) preds ...) - (let loop ((in (map possible-substitutions preds))) + (let loop ((in (map (lambda (pred) + (possible-substitutions pred structs)) + preds))) (match in (() '()) ((options . rest) @@ -1952,7 +2089,19 @@ rest-options))) options))))))) +(define (find-structs exp) + (match exp + ((? struct-type?) + (list exp)) + ((exps ...) + (append-map find-structs exps)) + (_ '()))) + (define (resolve-overloads program) + ;; Find all of the struct types used in the program. They will be + ;; used to generate overloaded functions that take one or more + ;; structs as arguments. + (define structs (delete-duplicates (find-structs program))) (match program (('t types ('top-level bindings body)) (define bindings* @@ -1960,7 +2109,8 @@ (match bindings (() '()) ((('function name ('t ((? for-all-type? type)) func)) . rest) - (define func-type (for-all-type-ref type)) + (define qtype (for-all-type-ref type)) + (define func-type (qualified-type-ref qtype)) (append (map (lambda (subs) (define type* (apply-substitutions-to-type func-type subs)) @@ -1968,7 +2118,9 @@ (apply-substitutions-to-exp func subs)) `(function ,name (t (,type*) ,func*))) (delete-duplicates - (possible-substitutions (for-all-type-predicate type)))) + (possible-substitutions + (qualified-type-predicate qtype) + structs))) (loop rest))) ((binding . rest) (cons binding (loop rest)))))) @@ -1983,7 +2135,9 @@ ;; Transform a fully typed Seagull program into a string of GLSL code. (define (type-name type) - (primitive-type-name type)) + (if (struct-type? type) + (struct-type-name type) + (primitive-type-name type))) (define (single-temp temps) (match temps @@ -2157,6 +2311,19 @@ ", ")) output-temps) +(define (emit:struct-ref type exp member version port level) + (define input-temp + (match (emit-glsl exp version port level) + ((temp) temp))) + (define output-temp (unique-identifier)) + (indent level port) + (format port "~a ~a = ~a.~a;\n" + (type-name type) + output-temp + input-temp + member) + (list output-temp)) + (define %type-name-map '((sampler-2d . sampler2D))) @@ -2215,6 +2382,8 @@ (emit:primcall type op args version port level)) (('t types ('call operator args ...)) (emit:call types operator args version port level)) + (('t (type) ('struct-ref exp member)) + (emit:struct-ref type exp member version port level)) (('t _ ('outputs (names exps) ...)) (emit:outputs names exps version port level)) (('t _ ('top-level (bindings ...) body)) -- cgit v1.2.3