;;; Chickadee Game Toolkit
;;; Copyright © 2023 David Thompson
;;;
;;; Chickadee is free software: you can redistribute it and/or modify
;;; it under the terms of the GNU General Public License as published
;;; by the Free Software Foundation, either version 3 of the License,
;;; or (at your option) any later version.
;;;
;;; Chickadee is distributed in the hope that it will be useful, but
;;; WITHOUT ANY WARRANTY; without even the implied warranty of
;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;;; General Public License for more details.
;;;
;;; You should have received a copy of the GNU General Public License
;;; along with this program.  If not, see
;;; <http://www.gnu.org/licenses/>.

;;; Commentary:
;;
;; The Seagull shading language is a purely functional, statically
;; typed, Scheme-like language that can be compiled to GLSL code.  The
;; reality of how GPUs work imposes some significant language
;; restrictions, but they are restrictions anyone who writes shader
;; code is already used to.
;;
;; Features:
;; - Purely functional
;; - Vertex and fragment shader output
;; - Targets multiple GLSL versions
;; - Type inference
;; - Lexical scoping
;; - Nested functions
;; - Multiple return values
;;
;; Limitations:
;; - No first-class functions
;; - No closures
;; - No recursion
;;
;; TODO:
;; - Loops
;; - discard
;; - (define ...) form
;; - struct field aliases (rgba for vec4, for example)
;; - Scheme shader type -> GLSL struct translation
;; - Dead code elimination (error when a uniform is eliminated)
;; - User defined structs
;; - Multiple GLSL versions
;; - Better error messages (especially around type predicate failure)
;; - Refactor to add define-primitive syntax
;;
;;; Code:

(define-module (chickadee graphics seagull)
  #:use-module (chickadee graphics shader)
  #:use-module (ice-9 exceptions)
  #:use-module (ice-9 format)
  #:use-module (ice-9 match)
  #:use-module (srfi srfi-1)
  #:use-module (srfi srfi-9)
  #:use-module (srfi srfi-11)
  #:export (compile-seagull-module
            compile-shader
            link-seagull-modules
            define-vertex-shader
            define-fragment-shader
            seagull-module?
            seagull-module-vertex?
            seagull-module-fragment?
            seagull-module-stage
            seagull-module-inputs
            seagull-module-outputs
            seagull-module-uniforms
            seagull-module-source
            seagull-module-compiled
            seagull-module-global-map))

;; The Seagull compiler is designed as a series of source-to-source
;; program transformations in which each transformation pass results
;; in a program that is one step closer to being directly emitted to
;; GLSL code.


;;;
;;; Compiler helpers
;;;

;; This is where we keep miscellaneous code that is useful for many
;; stages of the compiler.

;; Return #t if X is an inexact (floating point) number.
(define (float? x)
  (and (number? x) (inexact? x)))

;; Immediate types are fundamental data types that need no
;; compilation: exact integers, floats, and booleans.
(define (immediate? x)
  (or (exact-integer? x) (float? x) (boolean? x)))

(define (unary-operator? x)
  (eq? x 'not))

(define (arithmetic-operator? x)
  (memq x '(+ - * /)))

(define (comparison-operator? x)
  (memq x '(= < <= > >=)))

(define (binary-operator? x)
  (or (arithmetic-operator? x)
      (comparison-operator? x)))

(define (vector-constructor? x)
  (memq x '(vec2 vec3 vec4)))

(define (conversion? x)
  (memq x '(int->float float->int)))

(define (math-function? x)
  (memq x '(abs sqrt min max mod floor ceil sin cos tan clamp mix length)))

;; There are currently no primitives that exist only in the vertex
;; stage.
(define (vertex-primitive-call? x)
  #f)

(define (fragment-primitive-call? x)
  (memq x '(texture-2d)))

;; Return a true value if X names a primitive operation that is
;; available in shader STAGE ('vertex or 'fragment).
(define (primitive-call? x stage)
  (or (binary-operator? x)
      (unary-operator? x)
      (vector-constructor? x)
      (conversion? x)
      (math-function? x)
      (case stage
        ((vertex) (vertex-primitive-call? x))
        ((fragment) (fragment-primitive-call? x)))))

(define (top-level-qualifier? x)
  (memq x '(in out uniform)))

;; Return a true value if NAME is a built-in output variable for
;; shader STAGE.  Built-in outputs are not alpha converted.
(define (built-in-output? name stage)
  (case stage
    ((vertex)
     ;; GL 4+ has more built-ins, but we are supporting GL 2+ so we
     ;; can't use them easily.
     (memq name '(vertex:position vertex:point-size vertex:clip-distance)))
    ((fragment)
     ;; NOTE(review): a 'vertex:' prefix on a fragment-stage output
     ;; looks suspicious; confirm it matches the name the GLSL emitter
     ;; expects before renaming it.
     (memq name '(vertex:frag-depth)))))


;;;
;;; Lexical environments
;;;

;; Environments keep track of the variables that are in scope of an
;; expression.  They are association lists keyed by symbol; new
;; bindings are consed onto the front so that they shadow older ones.

(define (empty-env) '())

(define-syntax-rule (new-env (key value) ...)
  (list (cons key value) ...))

(define &seagull-unbound-variable-error
  (make-exception-type '&seagull-unbound-variable-error &error '(name)))

(define make-seagull-unbound-variable-error
  (record-constructor &seagull-unbound-variable-error))

(define seagull-unbound-variable-name
  (exception-accessor &seagull-unbound-variable-error
                      (record-accessor &seagull-unbound-variable-error 'name)))

;; Return the value bound to NAME in ENV.  Raise an unbound variable
;; error if there is no binding (or the bound value is #f).
(define (lookup name env)
  (or (assq-ref env name)
      (raise-exception
       (make-exception
        (make-seagull-unbound-variable-error name)
        (make-exception-with-origin lookup)
        (make-exception-with-message "seagull: unbound variable")
        (make-exception-with-irritants (list name env))))))

;; Like 'lookup', but return #f instead of raising an error.
(define (lookup* name env)
  (assq-ref env name))

(define (lookup-all names env)
  (map (lambda (name) (lookup name env)) names))

(define (extend-env name value env)
  (alist-cons name value env))

(define (compose-envs . envs)
  (concatenate envs))

(define (env-names env)
  (map car env))

(define (env-values env)
  (map cdr env))

(define (env-map proc env)
  (map (match-lambda
         ((name . exp)
          (proc name exp)))
       env))

(define (env-fold proc init env)
  (fold (lambda (e memo)
          (match e
            ((name . exp)
             (proc name exp memo))))
        init env))

(define (env-for-each proc env)
  (for-each (match-lambda
              ((name . exp)
               (proc name exp)))
            env))

(define (top-level-env)
  (empty-env))


;;;
;;; Macro expansion and alpha conversion
;;;

;; Macro expansion converts convenient but non-primitive syntax forms
;; (such as let*) into primitive syntax.  Seagull does not currently
;; support user defined macros, just a set of built-ins.
;;
;; Alpha conversion is the process of converting all the user defined
;; identifiers in a program to uniquely named identifiers.  This
;; process frees the compiler from having to worry about things like
;; '+' being a user defined variable that shadows the primitive
;; addition operation.

(define &seagull-syntax-error
  (make-exception-type '&seagull-syntax-error &error '(form)))

(define make-seagull-syntax-error
  (record-constructor &seagull-syntax-error))

(define seagull-syntax-form
  (exception-accessor &seagull-syntax-error
                      (record-accessor &seagull-syntax-error 'form)))

;; Raise a syntax error with the human readable message MSG for the
;; offending form EXP.  ORIGIN is the procedure that detected the
;; error.  The argument order (MSG EXP ORIGIN) is the order used by
;; every caller in this file.
(define (seagull-syntax-error msg exp origin)
  (raise-exception
   (make-exception
    (make-seagull-syntax-error exp)
    (make-exception-with-origin origin)
    (make-exception-with-message
     (format #f "seagull syntax error: ~a" msg))
    (make-exception-with-irritants (list exp)))))

(define unique-identifier-counter (make-parameter 0))

(define (unique-identifier-number)
  (let ((n (unique-identifier-counter)))
    (unique-identifier-counter (+ n 1))
    n))

;; Return a fresh symbol of the form VN.  Every user identifier is
;; renamed to such a symbol during expansion, so temporaries created
;; here by the compiler can never capture a user variable.
(define (unique-identifier)
  (string->symbol (format #f "V~a" (unique-identifier-number))))

(define (unique-identifiers-for-list lst)
  (map (lambda (_x) (unique-identifier)) lst))

;; Return an environment mapping each of NAMES to a fresh identifier.
(define (alpha-convert names)
  (define names* (map (lambda (_name) (unique-identifier)) names))
  (fold extend-env (empty-env) names names*))

(define (expand:list exps stage env)
  (map (lambda (exp) (expand exp stage env)) exps))

(define (expand:variable exp stage env)
  (lookup exp env))

(define (expand:if predicate consequent alternate stage env)
  `(if ,(expand predicate stage env)
       ,(expand consequent stage env)
       ,(expand alternate stage env)))

;; The bound expressions are expanded in the outer environment; only
;; BODY sees the new (alpha converted) names.
(define (expand:let names exps body stage env)
  (if (null? names)
      (expand body stage env)
      (let* ((exps* (map (lambda (exp) (expand exp stage env)) exps))
             (env* (compose-envs (alpha-convert names) env))
             (bindings* (map list (lookup-all names env*) exps*)))
        `(let ,bindings* ,(expand body stage env*)))))

(define (expand:let* bindings body stage env)
  (match bindings
    (() (expand body stage env))
    ((binding . rest)
     (expand `(let (,binding) (let* ,rest ,body)) stage env))))

(define (expand:lambda params body stage env)
  (define env* (compose-envs (alpha-convert params) env))
  (define params* (lookup-all params env*))
  `(lambda ,params* ,(expand body stage env*)))

(define (expand:values exps stage env)
  `(values ,@(expand:list exps stage env)))

;; (-> exp a b) => (struct-ref (struct-ref exp a) b)
(define (expand:-> exp fields stage env)
  (fold (lambda (field prev)
          `(struct-ref ,prev ,field))
        (expand exp stage env)
        fields))

;; (@ exp i j) => (array-ref (array-ref exp i) j)
(define (expand:@ exp indices stage env)
  (match indices
    (()
     (seagull-syntax-error "'@' form requires at least one index"
                           `(@ ,exp) expand:@))
    (_
     (fold (lambda (index prev)
             `(array-ref ,prev ,(expand index stage env)))
           (expand exp stage env)
           indices))))

;; Arithmetic operators, in true Lisp fashion, can accept many
;; arguments.  + and * accept 0 or more.  - and / accept one or more.
;; The expansion pass transforms all such expressions into binary
;; operator form.  - and / associate to the left, so (- a b c) means
;; (a - b) - c; with a single argument they mean negation and
;; reciprocal, respectively.

(define (expand:+ args stage env)
  (match args
    (() 0)
    ((n) (expand n stage env))
    ((n . rest)
     `(primcall + ,(expand n stage env) ,(expand:+ rest stage env)))))

;; Expand the non-empty list ARGS into a left-associated chain of
;; binary OP primcalls: (a b c) => (primcall op (primcall op a b) c).
(define (expand:left-assoc op args stage env)
  (match args
    ((m . rest)
     (fold (lambda (n prev)
             `(primcall ,op ,prev ,(expand n stage env)))
           (expand m stage env)
           rest))))

(define (expand:- args stage env)
  (match args
    (()
     (seagull-syntax-error "'-' requires at least one argument"
                           '(-) expand:-))
    ((n)
     ;; Negation.  Numeric literals are negated directly so that the
     ;; result keeps the literal's own type.
     (let ((n* (expand n stage env)))
       (if (number? n*)
           (- n*)
           `(primcall - 0 ,n*))))
    (_ (expand:left-assoc '- args stage env))))

(define (expand:* args stage env)
  (match args
    (() 1)
    ((n) (expand n stage env))
    ((n . rest)
     `(primcall * ,(expand n stage env) ,(expand:* rest stage env)))))

(define (expand:/ args stage env)
  (match args
    (()
     (seagull-syntax-error "'/' requires at least one argument"
                           '(/) expand:/))
    ((n)
     ;; Reciprocal.
     `(primcall / 1 ,(expand n stage env)))
    (_ (expand:left-assoc '/ args stage env))))

;; Build a short-circuiting 'or' out of already expanded expressions.
;; Each value is bound to a fresh identifier so it is evaluated only
;; once.  Because the identifier is fresh, it cannot shadow a user
;; variable that a later expression refers to.
(define (expanded-or exps*)
  (match exps*
    (() #f)
    ((exp* . rest)
     (let ((tmp (unique-identifier)))
       `(let ((,tmp ,exp*))
          (if ,tmp ,tmp ,(expanded-or rest)))))))

;; Build a short-circuiting 'and' out of already expanded expressions,
;; using fresh identifiers for the same reason as 'expanded-or'.
(define (expanded-and exps*)
  (match exps*
    (() #t)
    ((exp* . rest)
     (let ((tmp (unique-identifier)))
       `(let ((,tmp ,exp*))
          (if ,tmp ,(expanded-and rest) #f))))))

;; Every operand is expanded in the caller's environment, so user
;; variables are never captured by compiler temporaries.
(define (expand:or exps stage env)
  (expanded-or (expand:list exps stage env)))

(define (expand:and exps stage env)
  (expanded-and (expand:list exps stage env)))

(define (expand:cond clauses stage env)
  (define (cond->if clauses*)
    (match clauses*
      ;; Our version of 'cond' requires a final else clause because the
      ;; static type checker enforces that both branches of an 'if' must
      ;; be the same type.  If 'else' were optional then we wouldn't
      ;; know what type the final alternate branch should be.
      ((('else exp)) exp)
      (((predicate consequent) . rest)
       `(if ,predicate ,consequent ,(cond->if rest)))
      (()
       (seagull-syntax-error "'cond' form must end with 'else' clause"
                             `(cond ,@clauses) expand:cond))
      (_
       (seagull-syntax-error "invalid 'cond' form"
                             `(cond ,@clauses) expand:cond))))
  (expand (cond->if clauses) stage env))

(define (expand:case key clauses stage env)
  ;; KEY is evaluated once and bound to a fresh identifier, so it cannot
  ;; shadow user variables that the clauses refer to.
  (define key-var (unique-identifier))
  (define (possibilities-test possibilities)
    (expanded-or
     (map (lambda (n)
            `(primcall = ,key-var ,(expand n stage env)))
          possibilities)))
  (define (case->if clauses*)
    (match clauses*
      ;; Like 'cond', 'case' also requires a final 'else' clause.
      ((('else exp)) (expand exp stage env))
      ((((possibilities ..1) consequent) . rest)
       `(if ,(possibilities-test possibilities)
            ,(expand consequent stage env)
            ,(case->if rest)))
      (()
       (seagull-syntax-error "'case' form must end with 'else' clause"
                             `(case ,key ,@clauses) expand:case))
      (_
       (seagull-syntax-error "invalid 'case' form"
                             `(case ,key ,@clauses) expand:case))))
  (let ((key* (expand key stage env)))
    `(let ((,key-var ,key*)) ,(case->if clauses))))

(define (expand:primitive-call operator operands stage env)
  `(primcall ,operator ,@(expand:list operands stage env)))

(define (expand:call operator operands stage env)
  `(call ,(expand operator stage env) ,@(expand:list operands stage env)))

;; Return two values: the expanded 'top-level' form and the
;; environment mapping each top-level name to its alpha converted name.
(define (expand:top-level qualifiers types names body stage env)
  (let* ((global-map (alpha-convert names))
         (env* (compose-envs global-map env)))
    ;; TODO: Support interpolation qualifiers.
    (values `(top-level
              ,(map (lambda (qualifier type name)
                      (list qualifier type (lookup name env*)))
                    qualifiers types names)
              ,(expand body stage env*))
            global-map)))

(define (expand:outputs names exps stage env)
  `(outputs
    ,@(map (lambda (name exp)
             (list (if (built-in-output? name stage)
                       name
                       (lookup name env))
                   (expand exp stage env)))
           names exps)))

;; Expand the Seagull expression EXP for shader STAGE in lexical
;; environment ENV into primitive syntax with alpha converted names.
(define (expand exp stage env)
  (define (primitive-call-for-stage? x)
    (primitive-call? x stage))
  (match exp
    ;; Immediates and variables:
    ((? immediate?) exp)
    ((? symbol?) (expand:variable exp stage env))
    ;; Primitive syntax forms:
    (('if predicate consequent alternate)
     (expand:if predicate consequent alternate stage env))
    (('let (((? symbol? names) exps) ...) body)
     (expand:let names exps body stage env))
    (('lambda ((? symbol? params) ...) body)
     (expand:lambda params body stage env))
    (('values exps ...)
     (expand:values exps stage env))
    (('outputs ((? symbol? names) exps) ...)
     (expand:outputs names exps stage env))
    (('top-level (((? top-level-qualifier? qualifiers) types names) ...) body)
     (expand:top-level qualifiers types names body stage env))
    ;; Macros:
    (('-> exp (? symbol? members) ..1)
     (expand:-> exp members stage env))
    (('@ exp indices ...)
     (expand:@ exp indices stage env))
    (('let* (bindings ...) body)
     (expand:let* bindings body stage env))
    (('+ args ...) (expand:+ args stage env))
    (('- args ...) (expand:- args stage env))
    (('* args ...) (expand:* args stage env))
    (('/ args ...) (expand:/ args stage env))
    (('or exps ...) (expand:or exps stage env))
    (('and exps ...) (expand:and exps stage env))
    (('cond clauses ...) (expand:cond clauses stage env))
    (('case key clauses ...) (expand:case key clauses stage env))
    ;; Primitive calls:
    (((? primitive-call-for-stage? operator) args ...)
     (expand:primitive-call operator args stage env))
    ;; Function calls:
    ((operator args ...)
     (expand:call operator args stage env))
    ;; Syntax error:
    (_ (seagull-syntax-error "unknown form" exp expand))))


;;;
;;; Constant propagation and partial evaluation
;;;

;; Replace references to constants (variables that store an immediate
;; value: integer, float, boolean) with the constants themselves.
;; Then look for opportunities to evaluate primitive expressions that
;; have constant arguments.  This will make the type inferencer's job
;; a bit easier.

(define (propagate:if predicate consequent alternate env)
  `(if ,(propagate-constants predicate env)
       ,(propagate-constants consequent env)
       ,(propagate-constants alternate env)))

(define (propagate:lambda params body env)
  `(lambda ,params ,(propagate-constants body env)))

(define (propagate:values exps env)
  `(values ,@(map (lambda (exp) (propagate-constants exp env)) exps)))

(define (propagate:let names exps body env)
  (define exps* (map (lambda (exp) (propagate-constants exp env)) exps))
  ;; Extend environment with known constants.
  (define env*
    (fold (lambda (name exp env*)
            (if (immediate? exp)
                (extend-env name exp env*)
                env*))
          env names exps*))
  ;; Drop all bindings for constant expressions.
  (define bindings
    (filter-map (lambda (name exp)
                  (if (immediate? exp)
                      #f
                      (list name exp)))
                names exps*))
  ;; If there are no bindings left, remove the 'let' entirely.
  (if (null? bindings)
      (propagate-constants body env*)
      `(let ,bindings ,(propagate-constants body env*))))

(define (propagate:primcall operator args env)
  `(primcall ,operator
             ,@(map (lambda (arg) (propagate-constants arg env)) args)))

(define (propagate:call operator args env)
  `(call ,(propagate-constants operator env)
         ,@(map (lambda (arg) (propagate-constants arg env)) args)))

(define (propagate:struct-ref exp field env)
  `(struct-ref ,(propagate-constants exp env) ,field))

(define (propagate:array-ref array-exp index-exp env)
  `(array-ref ,(propagate-constants array-exp env)
              ,(propagate-constants index-exp env)))

;; The division of two integers in Scheme can produce an exact
;; rational such as 7/2.  GLSL integer division instead yields an
;; integer with the fractional part discarded (7 / 2 is 3), so exact
;; integers use truncating division.  Rounding would be wrong here:
;; 'round' rounds halves to even, so 7/2 would become 4.
(define (glsl-divide x y)
  (if (and (exact-integer? x) (exact-integer? y))
      (quotient x y)
      (/ x y)))

;; Fold 'X OP Y' into a constant when both operands are constants of
;; the same numeric type (both exact integers or both floats).
;; Division by a constant zero is never folded.  For exact integers,
;; Scheme would raise an error at compile time.  For floats, the
;; result would be an infinity or NaN, which has no GLSL literal
;; syntax.
(define (propagate:arithmetic op x y env)
  (define x* (propagate-constants x env))
  (define y* (propagate-constants y env))
  (define same-numeric-type?
    (or (and (exact-integer? x*) (exact-integer? y*))
        (and (float? x*) (float? y*))))
  (define division-by-zero?
    (and (eq? op '/) (number? y*) (zero? y*)))
  (if (and same-numeric-type? (not division-by-zero?))
      (let ((op* (case op
                   ((+) +)
                   ((-) -)
                   ((*) *)
                   ((/) glsl-divide))))
        (op* x* y*))
      `(primcall ,op ,x* ,y*)))

(define (propagate:top-level inputs body env)
  `(top-level ,inputs ,(propagate-constants body env)))

(define (propagate:outputs names exps env)
  `(outputs
    ,@(map (lambda (name exp)
             (list name (propagate-constants exp env)))
           names exps)))

;; Replace constant variable references in EXP using ENV (an
;; environment of variable -> immediate value) and fold constant
;; arithmetic.
(define (propagate-constants exp env)
  (match exp
    ((? immediate?) exp)
    ((? symbol?)
     ;; Use 'assq' rather than 'lookup*' so that a variable bound to
     ;; the constant #f is also replaced.  propagate:let has already
     ;; dropped its binding, so leaving the reference in place would
     ;; make it unbound.
     (match (assq exp env)
       ((_ . constant) constant)
       (#f exp)))
    (('if predicate consequent alternate)
     (propagate:if predicate consequent alternate env))
    (('lambda (params ...) body)
     (propagate:lambda params body env))
    (('values exps ...)
     (propagate:values exps env))
    (('let ((names exps) ...) body)
     (propagate:let names exps body env))
    (('primcall (and (or '+ '- '* '/) op) x y)
     (propagate:arithmetic op x y env))
    (('primcall operator args ...)
     (propagate:primcall operator args env))
    (('call operator args ...)
     (propagate:call operator args env))
    (('struct-ref exp field)
     (propagate:struct-ref exp field env))
    (('array-ref array-exp index-exp)
     (propagate:array-ref array-exp index-exp env))
    (('outputs (names exps) ...)
     (propagate:outputs names exps env))
    (('top-level inputs body)
     (propagate:top-level inputs body env))))


;;;
;;; Function hoisting
;;;

;; Move all lambda bindings to the top-level.  Unfortunately, GLSL
;; does not allow nested function definitions, so nested functions in
;; Seagull only allow free variable references for top-level
;; variables, such as shader inputs and uniforms.

(define &seagull-scope-error
  (make-exception-type '&seagull-scope-error &error '(variable)))

(define make-seagull-scope-error
  (record-constructor &seagull-scope-error))

(define seagull-scope-variable
  (exception-accessor &seagull-scope-error
                      (record-accessor &seagull-scope-error 'variable)))

(define (check-free-variables-in-list exps bound-vars top-level-vars)
  (every (lambda (exp)
           (check-free-variables exp bound-vars top-level-vars))
         exps))

;; Return a true value if every variable referenced in EXP is either
;; in BOUND-VARS or TOP-LEVEL-VARS; otherwise raise a scope error.
(define (check-free-variables exp bound-vars top-level-vars)
  (match exp
    ((? immediate?) #t)
    ((? symbol?)
     (or (memq exp bound-vars)          ; bound vars: OK
         (memq exp top-level-vars)      ; top-level vars: OK
         ;; Free variables that aren't top-level are not allowed because
         ;; GLSL doesn't support closures.
         (raise-exception
          (make-exception
           (make-seagull-scope-error exp)
           (make-exception-with-origin check-free-variables)
           (make-exception-with-message
            "seagull: free variable is not top-level")
           (make-exception-with-irritants (list exp))))))
    (('if predicate consequent alternate)
     (and (check-free-variables predicate bound-vars top-level-vars)
          (check-free-variables consequent bound-vars top-level-vars)
          (check-free-variables alternate bound-vars top-level-vars)))
    (('let ((names exps) ...) body)
     (define bound-vars* (append names bound-vars))
     (and (check-free-variables-in-list exps bound-vars* top-level-vars)
          (check-free-variables body bound-vars* top-level-vars)))
    (('lambda (params ...) body)
     ;; A lambda body only sees its own parameters (no closures).
     (check-free-variables body params top-level-vars))
    (('values exps ...)
     (check-free-variables-in-list exps bound-vars top-level-vars))
    ((or ('primcall _ args ...)
         ('call args ...))
     (check-free-variables-in-list args bound-vars top-level-vars))
    (('struct-ref exp _)
     (check-free-variables exp bound-vars top-level-vars))
    (('array-ref array-exp index-exp)
     (and (check-free-variables array-exp bound-vars top-level-vars)
          (check-free-variables index-exp bound-vars top-level-vars)))
    (('outputs (names exps) ...)
     (check-free-variables-in-list exps bound-vars top-level-vars))
    (('top-level ((_ _ names) ...) body)
     (define bound-vars* (append names bound-vars))
     (check-free-variables body bound-vars* top-level-vars))))

;; Hoist every expression in EXPS.  Return two values: the list of
;; hoisted expressions and the combined environment of hoisted
;; functions.
(define (hoist:list exps)
  (let-values (((exp-list env-list)
                (unzip2
                 (map (lambda (exp)
                        (call-with-values
                            (lambda () (hoist-functions exp))
                          list))
                      exps))))
    (values exp-list (apply compose-envs env-list))))

(define (hoist:if predicate consequent alternate)
  (define-values (predicate* predicate-env) (hoist-functions predicate))
  (define-values (consequent* consequent-env) (hoist-functions consequent))
  (define-values (alternate* alternate-env) (hoist-functions alternate))
  (values `(if ,predicate* ,consequent* ,alternate*)
          (compose-envs predicate-env consequent-env alternate-env)))

(define (hoist:let names exps body)
  (define-values (exps* exps-env) (hoist:list exps))
  (define-values (body* body-env) (hoist-functions body))
  ;; Remove all lambda bindings...
  (define bindings
    (filter-map (lambda (name exp)
                  (match exp
                    (('lambda _ _) #f)
                    (_ (list name exp))))
                names exps*))
  ;; ...and add them to the top-level environment.
  (define env*
    (fold (lambda (name exp env)
            (match exp
              (('lambda _ _) (extend-env name exp env))
              (_ env)))
          (compose-envs exps-env body-env)
          names exps*))
  ;; If there are no bindings left, remove the 'let'.
  (values (if (null? bindings)
              body*
              `(let ,bindings ,body*))
          env*))

(define (hoist:lambda params body)
  (define-values (body* body-env) (hoist-functions body))
  (values `(lambda ,params ,body*) body-env))

(define (hoist:values exps)
  (define-values (exps* exp-env) (hoist:list exps))
  (values `(values ,@exps*) exp-env))

(define (hoist:primcall operator args)
  (define-values (args* args-env) (hoist:list args))
  (values `(primcall ,operator ,@args*) args-env))

(define (hoist:call args)
  (define-values (args* args-env) (hoist:list args))
  (values `(call ,@args*) args-env))

(define (hoist:struct-ref exp field)
  (define-values (exp* exp-env) (hoist-functions exp))
  (values `(struct-ref ,exp* ,field) exp-env))

(define (hoist:array-ref array-exp index-exp)
  (define-values (array-exp* array-exp-env) (hoist-functions array-exp))
  (define-values (index-exp* index-exp-env) (hoist-functions index-exp))
  (values `(array-ref ,array-exp* ,index-exp*)
          (compose-envs array-exp-env index-exp-env)))

(define (hoist:top-level inputs body)
  (define-values (body* body-env) (hoist-functions body))
  (values `(top-level ,inputs ,body*) body-env))

(define (hoist:outputs names exps)
  (define-values (exps* exp-env) (hoist:list exps))
  (values `(outputs ,@(map (lambda (name exp) (list name exp)) names exps*))
          exp-env))

;; Remove lambda bindings from EXP.  Return two values: the rewritten
;; expression and an environment mapping function names to lambdas.
(define (hoist-functions exp)
  (match exp
    ((or (? immediate?) (? symbol?))
     (values exp (empty-env)))
    (('if predicate consequent alternate)
     (hoist:if predicate consequent alternate))
    (('let ((names exps) ...) body)
     (hoist:let names exps body))
    (('lambda (params ...) body)
     (hoist:lambda params body))
    (('values exps ...)
     (hoist:values exps))
    (('primcall operator args ...)
     (hoist:primcall operator args))
    (('call args ...)
     (hoist:call args))
    (('struct-ref exp member)
     (hoist:struct-ref exp member))
    (('array-ref array-exp index-exp)
     (hoist:array-ref array-exp index-exp))
    (('outputs (names exps) ...)
     (hoist:outputs names exps))
    (('top-level inputs body)
     (hoist:top-level inputs body))))

(define (maybe-merge-top-levels new-bindings exp)
  (match exp
    (('top-level bindings body)
     `(top-level ,(append bindings new-bindings) ,body))
    (_ `(top-level ,new-bindings ,exp))))

;; Hoist all functions in EXP to the top-level.  Check that hoisted
;; functions reference only their own bindings or top-level names.
(define (hoist-functions* exp)
  (define-values (exp* function-env) (hoist-functions exp))
  (define top-level-vars
    (append (env-names function-env)
            (map (match-lambda ((_ _ name) name))
                 (match exp*
                   (('top-level bindings _) bindings)
                   (_ '())))))
  (env-for-each (lambda (name exp)
                  (check-free-variables exp '() top-level-vars))
                function-env)
  (define bindings
    (env-map (lambda (name func)
               `(function ,name ,func))
             function-env))
  (maybe-merge-top-levels bindings exp*))


;;;
;;; Type inference
;;;

;; Walk the expression tree of a type annotated program and solve for
;; all of the type variables using a variant of the Hindley-Milner
;; type inference algorithm extended to handle qualified types (types
;; with predicates.)  GLSL is a statically typed language, but thanks
;; to type inference the user doesn't have to specify any types except
;; for shader inputs, outputs, and uniforms.

;; Primitive types:
(define (primitive-type name)
  `(primitive ,name))

(define (primitive-type? obj)
  (match obj
    (('primitive _) #t)
    (_ #f)))

(define (primitive-type-name type)
  (match type
    (('primitive name) name)))

;; Outputs type:
(define type:outputs '(outputs))

(define (outputs-type? obj)
  (equal? obj type:outputs))

;; Struct type:
(define (struct-type name fields)
  `(struct ,name ,fields))

(define (struct-type? obj)
  (match obj
    (('struct _ _) #t)
    (_ #f)))

(define (struct-type-name type)
  (match type
    (('struct name _) name)))

(define (struct-type-fields type)
  (match type
    (('struct _ fields) fields)))

(define (struct-type-ref type field)
  (assq-ref (struct-type-fields type) field))

(define-syntax-rule (define-struct-type (var-name name) (types names) ...)
  (define var-name
    (struct-type 'name (list (cons 'names types) ...))))

;; Array type:
(define (array-type type length)
  `(array ,type ,length))

(define (array-type? type)
  (match type
    (('array _ _) #t)
    (_ #f)))

(define (array-type-ref type)
  (match type
    (('array type _) type)))

(define (array-type-length type)
  (match type
    (('array _ n) n)))

;; Type variables:
(define unique-variable-type-counter (make-parameter 0))

(define (unique-variable-type-number)
  (let ((n (unique-variable-type-counter)))
    (unique-variable-type-counter (+ n 1))
    n))

(define (unique-variable-type-name)
  (string->symbol (format #f "T~a" (unique-variable-type-number))))

;; Note that every call returns a freshly consed list, so type
;; variables must be compared with 'equal?', not 'eq?'.
(define (variable-type name)
  `(tvar ,name))

(define (fresh-variable-type)
  (variable-type (unique-variable-type-name)))

(define (fresh-variable-types-for-list lst)
  (map (lambda (_x) (fresh-variable-type)) lst))

(define (variable-type? obj)
  (match obj
    (('tvar _) #t)
    (_ #f)))

;; Function types:
(define (function-type parameters returns)
  `(-> ,parameters ,returns))

(define (function-type? obj)
  (match obj
    (('-> _ _) #t)
    (_ #f)))

(define (function-type-parameters type)
  (match type
    (('-> params _) params)))

(define (function-type-returns type)
  (match type
    (('-> _ returns) returns)))

;; Type schemes:
(define (type-scheme quantifiers type)
  `(type-scheme ,quantifiers ,type))

(define (type-scheme? obj)
  (match obj
    (('type-scheme _ _) #t)
    (_ #f)))

(define (type-scheme-quantifiers type)
  (match type
    (('type-scheme q _) q)))

(define (type-scheme-ref type)
  (match type
    (('type-scheme _ t) t)))

;; Qualified types:
(define (qualified-type type pred)
  `(qualified ,type ,pred))

(define (qualified-type? obj)
  (match obj
    (('qualified _ _) #t)
    (_ #f)))

(define (qualified-type-ref type)
  (match type
    (('qualified type _) type)))

(define (qualified-type-predicate type)
  (match type
    (('qualified _ pred) pred)))

(define (type? obj)
  (or (primitive-type? obj)
      (variable-type? obj)
      (function-type? obj)
      (struct-type? obj)
      (outputs-type? obj)))

;; Replace the type variable FROM with TO everywhere within TYPE.
(define (apply-substitution-to-type type from to)
  (cond
   ((or (primitive-type? type)
        (struct-type? type)
        (outputs-type? type))
    type)
   ((variable-type? type)
    (if (equal? type from) to type))
   ((function-type? type)
    (function-type
     (map (lambda (param-type)
            (apply-substitution-to-type param-type from to))
          (function-type-parameters type))
     (map (lambda (return-type)
            (apply-substitution-to-type return-type from to))
          (function-type-returns type))))
   ((array-type? type)
    (array-type (apply-substitution-to-type (array-type-ref type) from to)
                (array-type-length type)))
   ((type-scheme? type)
    type)
   (else
    (error "invalid type" type))))

(define (apply-substitutions-to-type type subs)
  (env-fold (lambda (from to type*)
              (apply-substitution-to-type type* from to))
            type subs))

(define (apply-substitutions-to-types types subs)
  (map (lambda (type)
         (apply-substitutions-to-type type subs))
       types))

;; Apply the substitution FROM -> TO to every value of ENV, preserving
;; the order of bindings (and thus shadowing).  Each value is treated
;; as a list of types.  NOTE(review): 'lookup-type' treats a bound
;; value as a single type; confirm which representation the type
;; environment actually uses.
(define (apply-substitution-to-env env from to)
  (env-map (lambda (name types)
             (cons name
                   (map (lambda (type)
                          (apply-substitution-to-type type from to))
                        types)))
           env))

(define (apply-substitutions-to-env env subs)
  (env-fold (lambda (from to env*)
              (apply-substitution-to-env env* from to))
            env subs))

(define (apply-substitutions-to-texp t subs)
  (texp (apply-substitutions-to-types (texp-types t) subs)
        (texp-exp t)))

(define (apply-substitutions-to-exp exp subs)
  (match exp
    ((? type?) (apply-substitutions-to-type exp subs))
    ((exps ...)
     (map (lambda (exp)
            (apply-substitutions-to-exp exp subs))
          exps))
    (_ exp)))

;; Typed expressions:
(define (texp types exp)
  `(t ,types ,exp))

(define (texp? obj)
  (match obj
    (('t _ _) #t)
    (_ #f)))

(define (texp-types texp)
  (match texp
    (('t types _) types)))

(define (texp-exp texp)
  (match texp
    (('t _ exp) exp)))

(define (single-type texp)
  (match (texp-types texp)
    ((type) type)
    (_ (error "expected only 1 type" texp))))

(define &seagull-type-error
  (make-exception-type '&seagull-type-error &error '()))

(define make-seagull-type-error
  (record-constructor &seagull-type-error))

(define (seagull-type-error msg args origin)
  (raise-exception
   (make-exception
    (make-seagull-type-error)
    (make-exception-with-origin origin)
    (make-exception-with-message
     (format #f "seagull type error: ~a" msg))
    (make-exception-with-irritants args))))

;; Return #t if the type variable A occurs within B.  Type variables
;; are compared structurally because each one is a freshly consed
;; list.
(define (occurs? a b)
  (cond
   ((and (variable-type? a) (variable-type? b))
    (equal? a b))
   ((and (variable-type? a) (function-type? b))
    (or (occurs? a (function-type-parameters b))
        (occurs? a (function-type-returns b))))
   ((and (type? a) (list? b))
    (any (lambda (b*) (occurs? a b*)) b))
   (else #f)))

;; Compose substitution lists A and B: apply A to the right-hand
;; sides of B, then keep only those substitutions of A whose source
;; variable is not already substituted by B.  Source variables are
;; type variable lists, so keys are matched with 'assoc' ('equal?').
(define (compose-substitutions a b)
  (define b*
    (map (match-lambda
           ((from . to)
            (cons from (apply-substitutions-to-type to a))))
         b))
  (define a*
    (filter (match-lambda
              ((from . _)
               (not (assoc from b*))))
            a))
  (append a* b*))

(define (lookup-type name env)
  (let ((type (lookup name env)))
    (if (type-scheme? type)
        (instantiate type)
        type)))

(define (predicate:and . preds)
  ;; Combine inner 'and' predicates and remove #t predicates.
  (define preds*
    (let loop ((preds preds))
      (match preds
        (() '())
        ((('and sub-preds ...) . rest)
         (append sub-preds (loop rest)))
        ((#t . rest)
         (loop rest))
        ((pred . rest)
         (cons pred (loop rest))))))
  (match preds*
    (() #t)
    ((pred) pred)
    (_ `(and ,@preds*))))

(define (predicate:or . preds)
  (match preds
    (() #f)
    ((pred) pred)
    ((pred ('or sub-preds ...))
     `(or ,pred ,@sub-preds))
    (_ `(or ,@preds))))

(define (predicate:list . preds)
  ;; Flatten nested 'list' predicates.
  (define preds*
    (let loop ((preds preds))
      (match preds
        (() '())
        ((('list sub-preds ...) . rest)
         (append sub-preds (loop rest)))
        ((pred . rest)
         (cons pred (loop rest))))))
  `(list ,@preds*))

(define (predicate:= a b)
  `(= ,a ,b))

(define (predicate:substitute from to)
  `(substitute ,from ,to))

(define (predicate:substitutes subs)
  (apply predicate:and
         (map (match-lambda
                ((from . to)
                 (predicate:substitute from to)))
              subs)))

(define (predicate:struct-field struct field var)
  `(struct-field ,struct ,field ,var))

(define (predicate:array-element array var)
  `(array-element ,array ,var))

(define (compose-predicates a b)
  (cond
   ((and (eq? a #t) (eq? b #t)) #t)
   ((eq? a #t) b)
   ((eq? b #t) a)
   (else (predicate:list a b))))

(define (compose-predicates* preds)
  (reduce (lambda (pred memo)
            (compose-predicates pred memo))
          #t preds))

(define (apply-substitution-to-predicate pred from to)
  (match pred
    (#t #t)
    (#f #f)
    (('= a b)
     `(= ,(apply-substitution-to-type a from to)
         ,(apply-substitution-to-type b from to)))
    (('and preds ...)
     `(and ,@(map (lambda (pred)
                    (apply-substitution-to-predicate pred from to))
                  preds)))
    (('or preds ...)
     `(or ,@(map (lambda (pred)
                   (apply-substitution-to-predicate pred from to))
                 preds)))
    (('list preds ...)
     `(list ,@(map (lambda (pred)
                     (apply-substitution-to-predicate pred from to))
                   preds)))
    (('substitute a b)
     `(substitute ,(apply-substitution-to-type a from to)
                  ,(apply-substitution-to-type b from to)))
    (('struct-field struct field var)
     `(struct-field ,(apply-substitution-to-type struct from to)
                    ,field
                    ,(apply-substitution-to-type var from to)))
    (('array-element array var)
     `(array-element ,(apply-substitution-to-type array from to)
                     ,(apply-substitution-to-type var from to)))))

(define (apply-substitutions-to-predicate pred subs)
  (env-fold (lambda (from to pred*)
              (apply-substitution-to-predicate pred* from to))
            pred subs))

;; Produces a simplified predicate and a new set of substitutions for
;; predicates that have been satisfied and simplified to #t.  It's a
It's a ;; bit of a weird process since we're dealing with partial evaluation, ;; simplification, and constraint generation at the same time, so ;; there are lots of comments explaining what the heck is going on. (define (eval-predicate pred) (match pred ;; #t is the simplest predicate. It's always successful and ;; results in no substitutions. (#t (values #t '())) (#f (values #f '())) ;; '=' asserts that 'a' must equal 'b'. If either is a type ;; variable, then we don't have enough information to determine ;; success or failure. ((or ('= (? variable-type?) _) ('= _ (? variable-type?))) (values pred '())) ;; Neither argument is a type variable, so we can get an answer. (('= a b) (values (equal? a b) '())) ;; An 'or' succeeds if any of the predicates within succeed. (('or preds ...) (match preds ;; No clause succeeded or the 'or' had no clauses to begin ;; with. Either way, the 'or' fails. (() (values #f '())) ((pred* . rest) (define-values (new-pred subs) (eval-predicate pred*)) (match new-pred ;; This clause succeeded, so the entire 'or' can be ;; reduced to #t. (#t (values #t subs)) ;; This clause failed, so simplify the 'or' to just the ;; rest of the clauses. (#f (eval-predicate (apply predicate:or rest))) ;; There isn't enough information to determine if this ;; clause will succeed or fail. So, we evaluate the rest ;; of the 'or' clauses and compose the result with this ;; undetermined clause. (_ (define-values (rest-pred subs*) (eval-predicate (apply predicate:or rest))) (match rest-pred ;; All of the other 'or' clauses have failed, so the ;; 'or' can be removed entirely and reduced to the ;; undetermined predicate. (#f (values new-pred '())) ;; The rest of the 'or' succeeded, so we can form a ;; series of 'substitute' forms so that if this ;; undetermined clause fails the substitutions from the ;; rest of 'or' will still be respected. 
(#t (values (predicate:or new-pred (predicate:substitutes subs*)) '())) ;; The rest of the 'or' clauses have an undetermined ;; answer, so the result is an 'or'. (_ (values (predicate:or new-pred rest-pred) '())))))))) ;; An 'and' succeeds if *all* of the predicates within succeed. (('and preds ...) (match preds ;; All of the 'and' clauses succeed, so the 'and' succeeds. (() (values #t '())) ((pred* . rest) (define-values (new-pred subs) (eval-predicate pred*)) (match new-pred ;; The first 'and' clause is successful, so test the rest of ;; the clauses and compose the substitutes. (#t (let () (define-values (rest-pred subs*) (eval-predicate (apply predicate:and rest))) (match rest-pred ;; The rest of the 'and' clauses are successful, so the ;; entire 'and' is successful and we can return the ;; substitutions. (#t (values #t (compose-substitutions subs subs*))) ;; At least one of the remaining clauses has failed, so ;; the 'and' fails. (#f (values #f '())) ;; The rest of the 'and' is undetermined, so form a new ;; and that will perform the substitutions generated by ;; this clause if the rest of the 'and' eventually ;; succeeds. (_ (values (predicate:and rest-pred (predicate:substitutes subs)) '()) ;; (values (predicate:and rest-pred) ;; subs*) )))) ;; The clause failed, so the 'and' fails. (#f (values #f '())) ;; The clause is undetermined, so evaluate the rest of the ;; clauses and attempt to simply the resulting 'and' ;; expression. (_ (define-values (rest-pred subs*) (eval-predicate (apply predicate:and rest))) (match rest-pred ;; One of the remaining clauses fails, so even if the ;; undetermined clause were to succeed later, the 'and' ;; is going to fail so just fail now. (#f (values #f '())) ;; The remaining clauses pass, so we replace them with ;; substitutions that will occur should this ;; undertermined clause eventually succeed. 
(#t (values (predicate:and new-pred (predicate:substitutes subs*)) '())) ;; The remaining clauses have an undetermined result, so ;; construct a new 'and'. (_ (values (predicate:and new-pred rest-pred) '())))))))) ;; A 'list' predicate is like an 'or' except if any clause ;; succeeds the substitutions are propagated to the caller along ;; with a new list withou the successful clause. This is how ;; multiple independent predicates are composed. (('list preds ...) (match preds ;; No predicates, no failure. (() (values #t '())) ((pred* . rest) (define-values (new-pred subs) (eval-predicate pred*)) (match new-pred ;; This clause succeeded, so remove it from the result, eval ;; the rest of the clauses, and pass along the substitutions. (#t (let () (define-values (rest-pred subs*) (eval-predicate (apply predicate:list rest))) (values rest-pred (compose-substitutions subs subs*)))) ;; This clause failed, so the whole predicate fails. (#f (values #f '())) ;; There isn't enough information to determine if this ;; clause will succeed or fail. So, we evaluate the rest of ;; the clauses and compose the result with this undetermined ;; clause. (_ (define-values (rest-pred subs*) (eval-predicate (apply predicate:list rest))) (match rest-pred (#f (values #f '())) (#t (values new-pred subs*)) (_ (values (predicate:list pred* rest-pred) subs*)))))))) ;; Substitution always succeeds and returns a substitution to be ;; carried forward in the inference process. (('substitute a b) (values #t (list (cons a b)))) ;; Substitute the field var when struct has been resolved to a ;; struct type. (('struct-field struct field field-var) (if (struct-type? struct) (let ((field-type (struct-type-ref struct field))) (if field-type (values #t (list (cons field-var field-type))) (values #f '()))) (values pred '()))) ;; Substitute the element var when array has been resolved to an ;; array type. (('array-element array element-var) (if (array-type? 
array) (values #t (list (cons element-var (array-type-ref array)))) (values pred '()))))) (define (eval-predicate* pred subs) (define-values (new-pred pred-subs) (eval-predicate (apply-substitutions-to-predicate pred subs))) ;; TODO: Get information about *why* the predicate failed. (unless new-pred (error "predicate failure" pred)) ;; Recursively evaluate the predicate, applying the substitutions ;; generated by the last evaluation, until it cannot be simplified ;; any further. (if (null? pred-subs) (values new-pred subs) (eval-predicate* new-pred (compose-substitutions subs pred-subs)))) (define* (free-variables-in-type type) (cond ((or (primitive-type? type) (struct-type? type)) '()) ((array-type? type) (free-variables-in-type (array-type-ref type))) ((variable-type? type) (list type)) ((function-type? type) (let ((params (function-type-parameters type))) (filter (lambda (t) (member t params)) (delete-duplicates (append-map free-variables-in-type (function-type-returns type)))))) ((type-scheme? type) (fold delete (free-variables-in-type (type-scheme-ref type)) (type-scheme-quantifiers type))) (else (error "unknown type" type)))) (define (difference a b) (match a (() b) ((x . rest) (if (memq x b) (difference rest (delq x b)) (cons x (difference rest b)))))) (define (free-variables-in-type-scheme type-scheme) (difference (type-scheme-quantifiers type-scheme) (free-variables-in-type (type-scheme-ref type-scheme)))) (define (free-variables-in-env env) (delete-duplicates (env-fold (lambda (_name type vars) (cond ((variable-type? type) (cons (free-variables-in-type type) vars)) ((type-scheme? type) (cons (free-variables-in-type-scheme type) vars)) (else vars))) '() env))) (define (free-variables-in-predicate pred) (match pred (#t '()) (('= a b) (append (free-variables-in-type a) (free-variables-in-type b))) (('and preds ...) (append-map (lambda (pred) (free-variables-in-predicate pred)) preds)) (('or preds ...) 
(append-map (lambda (pred) (free-variables-in-predicate pred)) preds)) (('list preds ...) (append-map (lambda (pred) (free-variables-in-predicate pred)) preds)) (('substitute a b) (append (free-variables-in-type a) (free-variables-in-type b))) (('struct-field struct field var) (append (free-variables-in-type struct) (free-variables-in-type var))) (('array-element array var) (append (free-variables-in-type array) (free-variables-in-type var))))) ;; Quantified variables are type variables that appear free in the ;; function return types or in the predicate. (define (generalize type pred env) (if (function-type? type) (match (difference (delete-duplicates (append (free-variables-in-type type) (free-variables-in-predicate pred))) (free-variables-in-env env)) (() type) ((quantifiers ...) (type-scheme quantifiers (qualified-type type pred)))) type)) (define (instantiate type-scheme) (define subs (fold (lambda (var env) (extend-env var (fresh-variable-type) env)) (empty-env) (type-scheme-quantifiers type-scheme))) (define type (type-scheme-ref type-scheme)) (values (apply-substitutions-to-type (if (qualified-type? type) (qualified-type-ref type) type) subs) (if (qualified-type? type) (apply-substitutions-to-predicate (qualified-type-predicate type) subs) #t))) (define (maybe-instantiate types) (define types+preds (map (lambda (type) (if (type-scheme? type) (call-with-values (lambda () (instantiate type)) list) (list type #t))) types)) (values (map first types+preds) (reduce compose-predicates #t (map second types+preds)))) (define (unify:primitives a b) (if (equal? a b) '() (error "primitive type mismatch" a b))) (define (unify:structs a b) (if (equal? a b) '() (error "struct type mismatch" a b))) (define (unify:variable a b) (cond ((eq? a b) '()) ((occurs? 
a b) (error "type contains reference to itself" a b)) (else (list (cons a b))))) (define (unify:functions a b) (define param-subs (unify (function-type-parameters a) (function-type-parameters b))) (define return-subs (unify (apply-substitutions-to-types (function-type-returns a) param-subs) (apply-substitutions-to-types (function-type-returns b) param-subs))) (compose-substitutions param-subs return-subs)) (define (unify:lists a rest-a b rest-b) (define sub-first (unify a b)) (define sub-rest (unify (apply-substitutions-to-types rest-a sub-first) (apply-substitutions-to-types rest-b sub-first))) (compose-substitutions sub-first sub-rest)) (define (unify a b) (match (list a b) (((? primitive-type? a) (? primitive-type? b)) (unify:primitives a b)) (((? struct-type? a) (? struct-type? b)) (unify:structs a b)) ((or ((? variable-type? a) b) (b (? variable-type? a))) (unify:variable a b)) (((? function-type? a) (? function-type? b)) (unify:functions a b)) (((? outputs-type?) (? outputs-type?)) '()) (((? type?) (? type?)) (error "type mismatch" a b)) ((() ()) '()) (((a rest-a ...) (b rest-b ...)) (unify:lists a rest-a b rest-b)) (_ (error "type mismatch" a b)))) (define (infer:immediate x) (values (texp (list (cond ((exact-integer? x) type:int) ((float? x) type:float) ((boolean? x) type:bool))) x) '() #t)) (define (infer:variable name env) (define-values (types pred) (maybe-instantiate (lookup-type name env))) (values (texp types name) '() pred)) (define (infer:list exps env) (let loop ((exps exps) (texps '()) (subs '()) (pred #t)) (match exps (() (values (reverse texps) subs pred)) ((exp . rest) (define-values (texp subs* pred*) (infer-exp exp env)) (define-values (new-pred combined-subs) (eval-predicate* (compose-predicates pred pred*) (compose-substitutions subs subs*))) (loop rest (cons texp texps) combined-subs new-pred))))) (define (infer:if predicate consequent alternate env) ;; Infer predicate types and unify it with the boolean type. 
(define-values (predicate-texp predicate-subs predicate-pred) (infer-exp predicate env)) (define predicate-unify-subs (unify (texp-types predicate-texp) (list type:bool))) ;; Combine the substitutions and apply them to the environment. (define combined-subs-0 (compose-substitutions predicate-subs predicate-unify-subs)) (define env0 (apply-substitutions-to-env env combined-subs-0)) ;; Infer consequent and alternate types and unify them against each ;; other. Each branch of an 'if' should have the same type. (define-values (consequent-texp consequent-subs consequent-pred) (infer-exp consequent env0)) (define combined-subs-1 (compose-substitutions combined-subs-0 consequent-subs)) (define env1 (apply-substitutions-to-env env0 consequent-subs)) (define-values (alternate-texp alternate-subs alternate-pred) (infer-exp alternate env1)) (define combined-subs-2 (compose-substitutions combined-subs-1 alternate-subs)) ;; Eval combined predicate. (define-values (pred combined-subs-3) (eval-predicate* (compose-predicates predicate-pred (compose-predicates consequent-pred alternate-pred)) combined-subs-2)) ;; ;; Apply final set of substitutions to the types of both branches. (define consequent-texp* (apply-substitutions-to-texp consequent-texp combined-subs-3)) (define alternate-texp* (apply-substitutions-to-texp alternate-texp combined-subs-3)) (values (texp (texp-types consequent-texp) `(if ,predicate-texp ,consequent-texp ,alternate-texp)) combined-subs-3 pred)) (define (infer:lambda params body env) ;; Each function parameter gets a fresh type variable. (define param-types (fresh-variable-types-for-list params)) ;; The type environment is extended with the function parameters. 
(define env* (fold (lambda (param type env*) (extend-env param (list type) env*)) env params param-types)) (define-values (body* body-subs body-pred) (infer-exp body env*)) (define-values (pred subs) (eval-predicate* body-pred body-subs)) (values (texp (list (generalize (function-type (apply-substitutions-to-types param-types subs) (texp-types body*)) pred env)) `(lambda ,params ,body*)) subs #t)) (define (infer:primitive-call operator args env) ;; The type signature of primitive functions can be looked up ;; directly in the environment. Primitive functions may be ;; overloaded and need to be instantiated with fresh type variables. (define-values (types operator-pred) (maybe-instantiate (lookup-type operator env))) (define operator-type (match types ((type) type))) ;; Infer the arguments. (define-values (args* arg-subs arg-pred) (infer:list args env)) ;; Generate fresh type variables to unify against the return types ;; of the operator. (define return-vars (fresh-variable-types-for-list (function-type-returns operator-type))) (define call-subs (unify operator-type (function-type (map single-type args*) return-vars))) ;; Apply substitutions to the predicate and then eval it, producing ;; a simplified predicate and a set of substitutions. (define-values (pred combined-subs) (eval-predicate* (compose-predicates operator-pred arg-pred) (compose-substitutions arg-subs call-subs))) (values (texp (apply-substitutions-to-types return-vars combined-subs) `(primcall ,operator ,@(map (lambda (arg) (apply-substitutions-to-texp arg combined-subs)) args*))) combined-subs pred)) (define (infer:call operator args env) ;; The type signature of primitive functions can be looked up ;; directly in the environment. (define-values (operator* operator-subs operator-pred) (infer-exp operator env)) (define env* (apply-substitutions-to-env env operator-subs)) ;; Infer the arguments. 
(define-values (args* arg-subs arg-pred) (infer:list args env*)) (define combined-subs-0 (compose-substitutions operator-subs arg-subs)) ;; Generate fresh type variables to unify against the return types ;; of the operator. (define operator-type (single-type operator*)) (define return-vars (fresh-variable-types-for-list (function-type-returns operator-type))) (define call-subs (unify (apply-substitutions-to-type operator-type combined-subs-0) (function-type (apply-substitutions-to-types (map single-type args*) combined-subs-0) return-vars))) ;; Eval predicate. (define-values (pred combined-subs) (eval-predicate* (compose-predicates operator-pred arg-pred) (compose-substitutions combined-subs-0 call-subs))) (values (texp (apply-substitutions-to-types return-vars combined-subs) `(call ,(apply-substitutions-to-texp operator* combined-subs) ,@(map (lambda (arg) (apply-substitutions-to-texp arg combined-subs)) args*))) combined-subs pred)) (define (infer:struct-ref exp field env) (define-values (exp* exp-subs exp-pred) (infer-exp exp env)) (define exp-type (single-type exp*)) (define tvar (fresh-variable-type)) (define-values (pred combined-subs) (eval-predicate* (compose-predicates (predicate:struct-field exp-type field tvar) exp-pred) exp-subs)) (values (texp (list (apply-substitutions-to-type tvar combined-subs)) `(struct-ref ,(apply-substitutions-to-texp exp* combined-subs) ,field)) combined-subs pred)) (define (infer:array-ref array-exp index-exp env) (define-values (array-exp* array-exp-subs array-exp-pred) (infer-exp array-exp env)) (define array-type (single-type array-exp*)) (define env* (apply-substitutions-to-env env array-exp-subs)) (define-values (index-exp* index-exp-subs index-exp-pred) (infer-exp index-exp env*)) (define index-type (single-type index-exp*)) (define combined-subs (compose-substitutions array-exp-subs index-exp-subs)) ;; Array indices must be integers. 
(define unify-subs (unify (apply-substitutions-to-type index-type combined-subs) type:int)) (define tvar (fresh-variable-type)) (define-values (pred subs) (eval-predicate* (compose-predicates (predicate:array-element array-type tvar) (compose-predicates array-exp-pred index-exp-pred)) (compose-substitutions combined-subs unify-subs))) (define array-exp** (apply-substitutions-to-texp array-exp* subs)) (define index-exp** (apply-substitutions-to-texp index-exp* subs)) (values (texp (list tvar) `(array-ref ,array-exp** ,index-exp**)) subs pred)) (define (infer:let names exps body env) (define-values (exps* exp-subs exp-pred) (infer:list exps env)) (define exp-types (map texp-types exps*)) (define env* (fold extend-env (apply-substitutions-to-env env exp-subs) names exp-types)) (define-values (body* body-subs body-pred) (infer-exp body env*)) (define-values (pred combined-subs) (eval-predicate* (compose-predicates exp-pred body-pred) (compose-substitutions exp-subs body-subs))) (values (texp (texp-types body*) `(let ,(map (lambda (name exp) (list name (apply-substitutions-to-texp exp combined-subs))) names exps*) ,(apply-substitutions-to-texp body* combined-subs))) combined-subs pred)) (define (infer:values exps env) (define-values (exps* exp-subs exp-pred) (infer:list exps env)) (values (texp (map single-type exps*) `(values ,@exps*)) exp-subs exp-pred)) (define (infer:outputs names exps env) (define-values (exps* exp-subs exp-pred) (infer:list exps env)) (define exp-types (map texp-types exps*)) (define unify-subs (unify (map texp-types exps*) (map (lambda (name) (lookup name env)) names))) ;; Eval predicate. 
(define-values (pred combined-subs) (eval-predicate* exp-pred (compose-substitutions exp-subs unify-subs))) (values (texp (list type:outputs) `(outputs ,@(map (lambda (name exp) (list name (apply-substitutions-to-texp exp combined-subs))) names exps*))) combined-subs pred)) (define (infer:top-level bindings body env) (define (infer-bindings bindings texps subs pred env) (match bindings (() (values (reverse texps) subs pred env)) ((('function name exp) . rest) (define-values (texp subs* pred*) (infer-exp exp env)) (define-values (new-pred combined-subs) (eval-predicate* (compose-predicates pred pred*) (compose-substitutions subs subs*))) (define env* (apply-substitutions-to-env (extend-env name (texp-types texp) env) combined-subs)) (infer-bindings rest (cons texp texps) combined-subs new-pred env*)) (((_ desc name) . rest) (define types (list (type-descriptor->type desc))) (infer-bindings rest (cons types texps) subs pred (extend-env name types env))))) (define qualifiers (map first bindings)) (define names (map (match-lambda (('function name _) name) ((_ _ name) name)) bindings)) (define type-names (map (match-lambda (((? top-level-qualifier?) type-name _) type-name) (_ #f)) bindings)) (define-values (exps exp-subs exp-pred env*) (infer-bindings bindings '() '() #t env)) (define-values (body* body-subs body-pred) (infer-exp body env*)) (define-values (pred combined-subs) (eval-predicate* (compose-predicates exp-pred body-pred) (compose-substitutions exp-subs body-subs))) (define bindings* (map (match-lambda* (((? top-level-qualifier? qualifier) type-name name _) (list qualifier type-name name)) (('function _ name exp) `(function ,name ,(apply-substitutions-to-exp exp combined-subs)))) qualifiers type-names names exps)) (values (texp (texp-types body*) `(top-level ,bindings* ,body*)) combined-subs pred)) ;; Inference returns 3 values: ;; - a typed expression ;; - a list of substitutions ;; - a type predicate (define (infer-exp exp env) (match exp ((? immediate?) 
(infer:immediate exp)) ((? symbol? name) (infer:variable name env)) (('if predicate consequent alternate) (infer:if predicate consequent alternate env)) (('let ((names exps) ...) body) (infer:let names exps body env)) (('lambda (params ...) body) (infer:lambda params body env)) (('values exps ...) (infer:values exps env)) (('primcall operator args ...) (infer:primitive-call operator args env)) (('call operator args ...) (infer:call operator args env)) (('struct-ref exp field) (infer:struct-ref exp field env)) (('array-ref array-exp index-exp) (infer:array-ref array-exp index-exp env)) (('outputs (names exps) ...) (infer:outputs names exps env)) (('top-level bindings body) (infer:top-level bindings body env)) (_ (error "unknown form" exp)))) ;; Built-in types: (define type:int (primitive-type 'int)) (define type:float (primitive-type 'float)) (define type:bool (primitive-type 'bool)) (define-struct-type (type:vec2 vec2) (type:float x) (type:float y)) (define-struct-type (type:vec3 vec3) (type:float x) (type:float y) (type:float z)) (define-struct-type (type:vec4 vec4) (type:float x) (type:float y) (type:float z) (type:float w)) ;; TODO: Matrices are technically array types in GLSL, but we are ;; choosing to represent them opaquely for now to keep things simple. (define type:mat3 (primitive-type 'mat3)) (define type:mat4 (primitive-type 'mat4)) (define type:sampler-2d (primitive-type 'sampler2D)) (define (type-descriptor->type desc) (match desc ('bool type:bool) ('int type:int) ('float type:float) ('vec2 type:vec2) ('vec3 type:vec3) ('vec4 type:vec4) ('mat3 type:mat3) ('mat4 type:mat4) ('sampler-2d type:sampler-2d) (('array desc* (? exact-integer? length) (? exact-integer? rest) ...) (let loop ((rest rest) (prev (array-type (type-descriptor->type desc*) length))) (match rest (() prev) ((length . rest) (loop rest (array-type prev length)))))))) (define-syntax-rule (a+b->c (ta tb tc) ...) 
(let ((a (fresh-variable-type)) (b (fresh-variable-type)) (c (fresh-variable-type))) (list (type-scheme (list a b c) (qualified-type (function-type (list a b) (list c)) (predicate:or (predicate:and (predicate:= a ta) (predicate:= b tb) (predicate:substitute c tc)) ...)))))) (define (top-level-type-env stage) (define type:+/- (let ((a (fresh-variable-type))) (list (type-scheme (list a) (qualified-type (function-type (list a a) (list a)) (predicate:or (predicate:= a type:int) (predicate:= a type:float) (predicate:= a type:vec2) (predicate:= a type:vec3) (predicate:= a type:vec4) (predicate:= a type:mat3) (predicate:= a type:mat4))))))) (define type:* (a+b->c (type:int type:int type:int) (type:float type:float type:float) (type:int type:float type:float) (type:float type:int type:float) (type:vec2 type:vec2 type:vec2) (type:vec2 type:float type:vec2) (type:float type:vec2 type:vec2) (type:vec3 type:vec3 type:vec3) (type:vec3 type:float type:vec3) (type:float type:vec3 type:vec3) (type:vec4 type:vec4 type:vec4) (type:vec4 type:float type:vec4) (type:float type:vec4 type:vec4) (type:mat3 type:mat3 type:mat3) (type:mat3 type:vec3 type:vec3) (type:vec3 type:mat3 type:vec3) (type:mat4 type:mat4 type:mat4) (type:mat4 type:vec4 type:vec4) (type:vec4 type:mat4 type:vec4))) (define type:/ (a+b->c (type:int type:int type:int) (type:float type:float type:float) (type:float type:int type:float) (type:int type:float type:float) (type:vec2 type:vec2 type:vec2) (type:vec2 type:float type:vec2) (type:vec3 type:vec3 type:vec3) (type:vec3 type:float type:vec3) (type:vec4 type:vec4 type:vec4) (type:vec4 type:float type:vec4) (type:mat3 type:float type:mat3) (type:mat4 type:float type:mat4))) (define type:mod (a+b->c (type:float type:float type:float) (type:int type:int type:float) (type:vec2 type:vec2 type:vec2) (type:vec3 type:vec3 type:vec3) (type:vec4 type:vec4 type:vec4) (type:vec2 type:float type:vec2) (type:vec3 type:float type:vec3) (type:vec4 type:float type:vec4))) (define 
type:floor/ceil (let ((a (fresh-variable-type))) (list (type-scheme (list a) (qualified-type (function-type (list a) (list a)) (predicate:or (predicate:= a type:float) (predicate:= a type:vec2) (predicate:= a type:vec3) (predicate:= a type:vec4))))))) (define type:int->float (list (function-type (list type:int) (list type:float)))) (define type:float->int (list (function-type (list type:float) (list type:int)))) (define type:comparison (let ((a (fresh-variable-type))) (list (type-scheme (list a) (qualified-type (function-type (list a a) (list type:bool)) (predicate:or (predicate:= a type:int) (predicate:= a type:float))))))) (define type:not (list (function-type (list type:bool) (list type:bool)))) (define type:make-vec2 (list (function-type (list type:float type:float) (list type:vec2)))) (define type:make-vec3 (list (function-type (list type:float type:float type:float) (list type:vec3)))) (define type:make-vec4 (list (function-type (list type:float type:float type:float type:float) (list type:vec4)))) (define type:length (let ((a (fresh-variable-type))) (list (type-scheme (list a) (qualified-type (function-type (list a) (list type:float)) (predicate:or (predicate:= a type:float) (predicate:= a type:vec2) (predicate:= a type:vec3) (predicate:= a type:vec4))))))) (define type:abs (let ((a (fresh-variable-type))) (list (type-scheme (list a) (qualified-type (function-type (list a) (list a)) (predicate:or (predicate:= a type:int) (predicate:= a type:float))))))) (define type:sqrt (let ((a (fresh-variable-type))) (list (type-scheme (list a) (qualified-type (function-type (list a) (list a)) (predicate:or (predicate:= a type:int) (predicate:= a type:float))))))) (define type:min/max (let ((a (fresh-variable-type))) (list (type-scheme (list a) (qualified-type (function-type (list a a) (list a)) (predicate:or (predicate:= a type:int) (predicate:= a type:float))))))) (define type:trig (list (function-type (list type:float) (list type:float)))) (define type:clamp (let ((a 
(fresh-variable-type))) (list (type-scheme (list a) (qualified-type (function-type (list a a a) (list a)) (predicate:or (predicate:= a type:int) (predicate:= a type:float))))))) (define type:mix (let ((a (fresh-variable-type)) (b (fresh-variable-type))) (list (type-scheme (list a) (qualified-type (function-type (list a a type:float) (list a)) (predicate:or (predicate:= a type:int) (predicate:= a type:float) (predicate:= a type:vec4))))))) (define type:texture-2d (list (function-type (list type:sampler-2d type:vec2) (list type:vec4)))) `((+ . ,type:+/-) (- . ,type:+/-) (* . ,type:*) (/ . ,type:/) (mod . ,type:mod) (floor . ,type:floor/ceil) (ceil . ,type:floor/ceil) (int->float . ,type:int->float) (float->int . ,type:float->int) (= . ,type:comparison) (< . ,type:comparison) (<= . ,type:comparison) (> . ,type:comparison) (>= . ,type:comparison) (not . ,type:not) (vec2 . ,type:make-vec2) (vec3 . ,type:make-vec3) (vec4 . ,type:make-vec4) (length . ,type:length) (abs . ,type:abs) (sqrt . ,type:sqrt) (min . ,type:min/max) (max . ,type:min/max) (sin . ,type:trig) (cos . ,type:trig) (tan . ,type:trig) (clamp . ,type:clamp) (mix . ,type:mix) ,@(case stage ((vertex) `((vertex:position ,type:vec4) (vertex:point-size ,type:float) (vertex:clip-distance ,type:float))) ((fragment) `((fragment:depth ,type:float) (texture-2d . ,type:texture-2d)))))) ;; TODO: Add some kind of context object that is threaded through the ;; inference process so that when a type error occurs we can show the ;; expression that caused it. (define (infer-types exp stage) (infer-exp exp (top-level-type-env stage))) ;;; ;;; Overloaded functions ;;; ;; Replace quantified functions ('type-scheme' expressions) with a series ;; of non-quantified function type specifications, one for each unique ;; type of call in the program. 
;; Collect the call signatures of NAME found in each typed expression
;; of TEXPS.
(define (find-signatures:list name texps)
  (append-map (lambda (texp*)
                (find-signatures name texp*))
              texps))

;; Collect call signatures of NAME from all three parts of an 'if'.
(define (find-signatures:if name predicate consequent alternate)
  (append (find-signatures name predicate)
          (find-signatures name consequent)
          (find-signatures name alternate)))

;; Collect call signatures of NAME from 'let' bindings and body.
(define (find-signatures:let name binding-texps body)
  (append (find-signatures:list name binding-texps)
          (find-signatures name body)))

;; Collect call signatures of NAME from an array access.
(define (find-signatures:array-ref name array index)
  (append (find-signatures name array)
          (find-signatures name index)))

;; Walk the typed expression TEXP and return a list of function types,
;; one for each 'call' whose operator is NAME.  Each is built from the
;; argument types and the call's result types.
(define (find-signatures name texp)
  (match (texp-exp texp)
    ((or (? immediate?) (? symbol?)) '())
    (('if predicate consequent alternate)
     (find-signatures:if name predicate consequent alternate))
    (('let ((_ exps) ...) body)
     (find-signatures:let name exps body))
    (('values exps ...)
     (find-signatures:list name exps))
    (('primcall _ args ...)
     (find-signatures:list name args))
    (('call operator args ...)
     ;; Only calls to NAME contribute a signature.  Calls to other
     ;; functions must not add an entry, which a one-armed 'if' would
     ;; do with an unspecified value.
     (let ((arg-signatures (find-signatures:list name args)))
       (if (eq? (texp-exp operator) name)
           (cons (function-type (map single-type args) (texp-types texp))
                 arg-signatures)
           arg-signatures)))
    (('struct-ref struct _)
     (find-signatures name struct))
    (('array-ref array index)
     (find-signatures:array-ref name array index))
    (('outputs (_ exps) ...)
     (find-signatures:list name exps))
    (_ (error "uh oh" texp))))

;; Build substitutions from the typed variable references in EXP.
;; Each '(t (TVAR) NAME) node whose NAME 'lookup*' finds in ENV
;; contributes (TVAR . binding).
(define (vars->subs exp env)
  (match exp
    (('t ((? variable-type? tvar)) (? symbol? name))
     (let ((type (lookup* name env)))
       (if type
           (list (cons tvar type))
           '())))
    ((head . rest)
     (delete-duplicates
      (append (vars->subs head env)
              (vars->subs rest env))))
    (_ '())))

;; Strip every '(t TYPES EXP) annotation from X, recursively.
(define (untype x)
  (match x
    (('t (_ ...) exp) (untype exp))
    ((exp . rest) (cons (untype exp) (untype rest)))
    (_ x)))

;; Specialize each top-level function whose type is a type scheme into
;; one monomorphic 'function' binding per distinct call signature found
;; in the program body.  Other bindings are kept as they are.
(define (resolve-overloads program stage)
  ;; Find all of the struct types used in the program.  They will be
  ;; used to generate overloaded functions that take one or more
  ;; structs as arguments.
  ;;(define structs (delete-duplicates (find-structs program)))
  (match program
    (('t types ('top-level bindings body))
     (define bindings*
       (let loop ((bindings bindings)
                  (globals (empty-env)))
         (match bindings
           (() '())
           ((('function name ('t ((? type-scheme? type)) func)) . rest)
            (define qtype (type-scheme-ref type))
            (define func-type (qualified-type-ref qtype))
            ;; Produce the 'function' binding for one call signature.
            (define (specialize call-type)
              (define subs (unify func-type call-type))
              (define type* (apply-substitutions-to-type func-type subs))
              (define params
                (match func
                  (('lambda (params ...) _) params)))
              (define env
                (compose-envs (fold extend-env (empty-env) params
                                    (map list (function-type-parameters type*)))
                              globals))
              ;; NOTE(review): the result of this inference is discarded;
              ;; presumably it runs so that type errors in this
              ;; specialization are raised here.  Confirm.
              (match func
                (('lambda _ body)
                 (infer-exp (untype body)
                            (compose-envs env (top-level-type-env stage)))))
              (define subs* (compose-substitutions subs (vars->subs func env)))
              (define func* (apply-substitutions-to-exp func subs*))
              `(function ,name (t (,type*) ,func*)))
            (append (map specialize
                         (delete-duplicates (find-signatures name body)))
                    (loop rest globals)))
           ((('function name texp) . rest)
            (cons `(function ,name ,texp)
                  (loop rest globals)))
           (((qualifier type name) . rest)
            (cons (list qualifier type name)
                  (loop rest
                        (extend-env name
                                    (list (type-descriptor->type type))
                                    globals)))))))
     `(t ,types (top-level ,bindings* ,body)))
    (_ (error "expected top-level form" program))))


;;;
;;; GLSL emission
;;;

;; Transform a fully typed Seagull program into a string of GLSL code.

;; Map a type descriptor (symbol or '(array DESC LENGTH)) to its GLSL
;; type name, e.g. "float[4]".
(define (type-descriptor->glsl desc)
  (match desc
    ((? symbol?)
     (match (type-descriptor->type desc)
       ((? primitive-type? primitive)
        (primitive-type-name primitive))
       ((? struct-type? struct)
        (struct-type-name struct))))
    (('array desc* length)
     (format #f "~a[~a]" (type-descriptor->glsl desc*) length))))

;; Convert a primitive, struct, or array type back into a type
;; descriptor built from type names.
(define (type->type-descriptor type)
  (cond
   ((primitive-type? type) (primitive-type-name type))
   ((struct-type? type) (struct-type-name type))
   ((array-type? type)
    `(array ,(type->type-descriptor (array-type-ref type))
            ,(array-type-length type)))))

;; GLSL name of TYPE.  Types are mapped directly rather than going
;; through type descriptors, because a GLSL name such as 'sampler2D is
;; not itself a valid type descriptor.
(define (type->glsl type)
  (cond
   ((primitive-type? type) (primitive-type-name type))
   ((struct-type? type) (struct-type-name type))
   ((array-type? type)
    (format #f "~a[~a]"
            (type->glsl (array-type-ref type))
            (array-type-length type)))
   (else (error "no GLSL representation for type" type))))

;; Extract the only element of TEMPS.  Match fails otherwise.
(define (single-temp temps)
  (match temps
    ((temp) temp)))

;; Write 2*N spaces of indentation to PORT.
(define (indent n port)
  (when (> n 0)
    (display (make-string (* n 2) #\space) port)))

;; Emit an int temporary initialized to N.  Returns (list temp).
(define (emit:int n stage version port level)
  (define temp (unique-identifier))
  (indent level port)
  (format port "int ~a = ~a;\n" temp n)
  (list temp))

;; Emit a float temporary initialized to N.  Infinities and NaN have
;; no GLSL literal, so they are written as divisions by zero.  The
;; sign of an infinity is preserved.
(define (emit:float n stage version port level)
  (define temp (unique-identifier))
  (indent level port)
  (format port "float ~a = ~a;\n" temp
          (cond
           ((nan? n) "0.0 / 0.0")
           ((inf? n) (if (negative? n) "-1.0 / 0.0" "1.0 / 0.0"))
           (else n)))
  (list temp))

;; Emit a bool temporary initialized to B.  Returns (list temp).
(define (emit:boolean b stage version port level)
  (define temp (unique-identifier))
  (indent level port)
  (format port "bool ~a = ~a;\n" temp (if b "true" "false"))
  (list temp))

;; Emit infix operator OP applied to A and B.  Seagull '=' becomes
;; GLSL '=='.
(define (emit:binary-operator type op a b stage version port level)
  (define op* (case op
                ((=) '==)
                (else op)))
  (define a-temp (single-temp (emit-glsl a stage version port level)))
  (define b-temp (single-temp (emit-glsl b stage version port level)))
  (define temp (unique-identifier))
  (indent level port)
  (format port "~a ~a = ~a ~a ~a;\n" (type->glsl type) temp a-temp op* b-temp)
  (list temp))

;; Emit prefix operator OP applied to A.  Seagull 'not' becomes GLSL '!'.
(define (emit:unary-operator type op a stage version port level)
  (define op* (case op
                ((not) '!)
                (else op)))
  (define a-temp (single-temp (emit-glsl a stage version port level)))
  (define temp (unique-identifier))
  (indent level port)
  (format port "~a ~a = ~a(~a);\n" (type->glsl type) temp op* a-temp)
  (list temp))

;; Declare variable LHS of TYPE, initialized to RHS when RHS is true.
;; Nothing is emitted for the outputs type.
(define (emit:declaration type lhs rhs port level)
  (unless (outputs-type? type)
    (indent level port)
    (if rhs
        (format port "~a ~a = ~a;\n" (type->glsl type) lhs rhs)
        (format port "~a ~a;\n" (type->glsl type) lhs))))

;; Declare several variables.  RHS-LIST may be #f to declare them all
;; uninitialized.
(define (emit:declarations types lhs-list rhs-list port level)
  (define rhs-list* (if rhs-list
                        rhs-list
                        (make-list (length lhs-list) #f)))
  (for-each (lambda (type lhs rhs)
              (emit:declaration type lhs rhs port level))
            types lhs-list rhs-list*))

;; Emit the assignment "A = B;".  Nothing is emitted when A is #f.
(define (emit:mov a b port level)
  (when a
    (indent level port)
    (format port "~a = ~a;\n" a b)))

;; Emit a 'void' GLSL function.  Parameters become 'in' arguments and
;; return values become 'out' arguments assigned at the end of the body.
(define (emit:function name type params body stage version port level)
  (define param-types (function-type-parameters type))
  (define return-types (function-type-returns type))
  (define outputs (unique-identifiers-for-list return-types))
  (indent level port)
  (format port "void ~a(" name)
  (let loop ((entries (append (zip (make-list (length params) 'in)
                                   param-types
                                   params)
                              (zip (make-list (length return-types) 'out)
                                   return-types
                                   outputs)))
             (first? #t))
    (match entries
      (() #t)
      (((qualifier type name) . rest)
       (unless first?
         (display ", " port))
       (format port "~a ~a ~a" qualifier (type->glsl type) name)
       (loop rest #f))))
  (display ") {\n" port)
  (define body-temps (emit-glsl body stage version port (+ level 1)))
  (for-each (lambda (output temp)
              (emit:mov output temp port (+ level 1)))
            outputs body-temps)
  (indent level port)
  (display "}\n" port))

;; Emit an if/else.  Both branches assign into shared temporaries that
;; are declared before the 'if'.  When the branches produce 'outputs',
;; a single #f placeholder is used instead and nothing is assigned.
(define (emit:if predicate consequent alternate stage version port level)
  (define if-temps
    (if (equal? (texp-types consequent) (list type:outputs))
        '(#f)
        (unique-identifiers-for-list (texp-types consequent))))
  (emit:declarations (texp-types consequent) if-temps #f port level)
  (define predicate-temp
    (single-temp (emit-glsl predicate stage version port level)))
  (indent level port)
  (format port "if(~a) {\n" predicate-temp)
  (define consequent-temps
    (emit-glsl consequent stage version port (+ level 1)))
  (for-each (lambda (lhs rhs)
              (emit:mov lhs rhs port (+ level 1)))
            if-temps consequent-temps)
  (indent level port)
  (display "} else {\n" port)
  (define alternate-temps
    (emit-glsl alternate stage version port (+ level 1)))
  (for-each (lambda (lhs rhs)
              (emit:mov lhs rhs port (+ level 1)))
            if-temps alternate-temps)
  (indent level port)
  (display "}\n" port)
  if-temps)

;; Emit each expression and return all of their temporaries in order.
(define (emit:values exps stage version port level)
  (append-map (lambda (exp)
                (emit-glsl exp stage version port level))
              exps))

;; Emit a 'let': declare each name initialized from its expression's
;; temporary, emit the body, and copy its results into fresh temporaries.
(define (emit:let types names exps body stage version port level)
  (define binding-temps
    (map (lambda (exp)
           (single-temp (emit-glsl exp stage version port level)))
         exps))
  (define binding-types (map single-type exps))
  (emit:declarations binding-types names binding-temps port level)
  (define body-temps (emit-glsl body stage version port level))
  (define let-temps (unique-identifiers-for-list types))
  (emit:declarations (texp-types body) let-temps body-temps port level)
  let-temps)

;; Seagull primitive names whose GLSL function name differs.
(define %primcall-map
  '((float->int . int)
    (int->float . float)
    (texture-2d . texture2D)))

;; Emit a call to a built-in GLSL function, renamed through
;; '%primcall-map' where needed.  Returns (list temp).
(define (emit:primcall type operator args stage version port level)
  (define operator* (or (assq-ref %primcall-map operator) operator))
  (define arg-temps
    (map (lambda (arg)
           (single-temp (emit-glsl arg stage version port level)))
         args))
  (define output-temp (unique-identifier))
  (indent level port)
  (format port "~a ~a = ~a(~a);\n"
          (type->glsl type)
          output-temp
          operator*
          (string-join (map symbol->string arg-temps) ", "))
  (list output-temp))

;; Emit a call to a user-defined function.  Its results come back
;; through 'out' arguments, declared beforehand as temporaries.
(define (emit:call types operator args stage version port level)
  ;; LEVEL is passed through, as at every other 'emit-glsl' call site.
  (define operator-name
    (single-temp (emit-glsl operator stage version port level)))
  (define arg-temps
    (map (lambda (arg)
           (single-temp (emit-glsl arg stage version port level)))
         args))
  (define output-temps (unique-identifiers-for-list types))
  (emit:declarations types output-temps #f port level)
  (indent level port)
  (format port "~a(~a);\n"
          operator-name
          (string-join (map symbol->string (append arg-temps output-temps))
                       ", "))
  output-temps)

;; Emit a struct field access.  Returns (list temp).
(define (emit:struct-ref type exp field stage version port level)
  (define input-temp (single-temp (emit-glsl exp stage version port level)))
  (define output-temp (unique-identifier))
  (indent level port)
  (format port "~a ~a = ~a.~a;\n"
          (type->glsl type) output-temp input-temp field)
  (list output-temp))

;; Emit an array element access.  Returns (list temp).
(define (emit:array-ref type array-exp index-exp stage version port level)
  (define array-temp
    (single-temp (emit-glsl array-exp stage version port level)))
  (define index-temp
    (single-temp (emit-glsl index-exp stage version port level)))
  (define output-temp (unique-identifier))
  (indent level port)
  (format port "~a ~a = ~a[~a];\n"
          (type->glsl type) output-temp array-temp index-temp)
  (list output-temp))

(define (emit:top-level bindings body stage version port level) (for-each (match-lambda (((? top-level-qualifier?
qualifier) type-desc name) (format port "~a ~a ~a;\n" qualifier (type-descriptor->glsl type-desc) name)) (('function name ('t (type) ('lambda params body))) (emit:function name type params body stage version port level))) bindings) (display "void main() {\n" port) (emit-glsl body stage version port (+ level 1)) (display "}\n" port)) (define %built-in-output-map '((vertex:position . gl_Position) (vertex:point-size . gl_PointSize) (vertex:clip-distance . gl_ClipDistance) (fragment:depth . gl_FragDepth))) (define (emit:outputs names exps stage version port level) (define (output-name name) (or (assq-ref %built-in-output-map name) name)) (if (and (eq? stage 'fragment) (null? names)) (begin (indent level port) (format port "discard;\n")) (for-each (lambda (name exp) (match (emit-glsl exp stage version port level) ((temp) (indent level port) (format port "~a = ~a;\n" (output-name name) temp)))) names exps)) '(#f)) (define* (emit-glsl exp stage version port #:optional (level 0)) (match exp (('t _ (? exact-integer? n)) (emit:int n stage version port level)) (('t _ (? float? n)) (emit:float n stage version port level)) (('t _ (? boolean? b)) (emit:boolean b stage version port level)) (('t _ (? symbol? var)) (list var)) (('t _ ('if predicate consequent alternate)) (emit:if predicate consequent alternate stage version port level)) (('t _ ('values exps ...)) (emit:values exps stage version port level)) (('t types ('let ((names exps) ...) body)) (emit:let types names exps body stage version port level)) (('t (type) ('primcall (? binary-operator? op) a b)) (emit:binary-operator type op a b stage version port level)) (('t (type) ('primcall (? unary-operator? 
op) a)) (emit:unary-operator type op a stage version port level)) (('t (type) ('primcall op args ...)) (emit:primcall type op args stage version port level)) (('t types ('call operator args ...)) (emit:call types operator args stage version port level)) (('t (type) ('struct-ref exp field)) (emit:struct-ref type exp field stage version port level)) (('t (type) ('array-ref array-exp index-exp)) (emit:array-ref type array-exp index-exp stage version port level)) (('t _ ('outputs (names exps) ...)) (emit:outputs names exps stage version port level)) (('t _ ('top-level (bindings ...) body)) (emit:top-level bindings body stage version port level)))) ;;; ;;; Compiler front-end ;;; ;; Combine all of the compiler passes on a user provided program and ;; emit GLSL code if the program is valid. (define-record-type (make-seagull-global qualifier type-descriptor name) seagull-global? (qualifier seagull-global-qualifier) (type-descriptor seagull-global-type-descriptor) (name seagull-global-name)) (define-record-type (%make-seagull-module stage inputs outputs uniforms source compiled global-map max-id) seagull-module? (stage seagull-module-stage) (inputs seagull-module-inputs) (outputs seagull-module-outputs) (uniforms seagull-module-uniforms) (source seagull-module-source) (compiled seagull-module-compiled) ;; Original name -> alpha converted name mapping for inputs, ;; outputs, and uniforms. (global-map seagull-module-global-map) (max-id seagull-module-max-id)) (define* (make-seagull-module #:key stage inputs outputs uniforms source compiled global-map max-id) (%make-seagull-module stage inputs outputs uniforms source compiled global-map max-id)) (define (seagull-module-vertex? module) (eq? (seagull-module-stage module) 'vertex)) (define (seagull-module-fragment? module) (eq? (seagull-module-stage module) 'fragment)) (define (group-by-qualifier specs) (let loop ((specs specs) (inputs '()) (outputs '()) (uniforms '())) (match specs (() `((inputs . ,(reverse inputs)) (outputs . 
,(reverse outputs)) (uniforms . ,(reverse uniforms)))) ((spec . rest) (match spec (('in type-desc name) (loop rest (cons spec inputs) outputs uniforms)) (('out type-desc name) (loop rest inputs (cons spec outputs) uniforms)) (('uniform type-desc name) (loop rest inputs outputs (cons spec uniforms)))))))) (define* (compile-seagull #:key stage source (inputs '()) (outputs '()) (uniforms '())) (unless (memq stage '(vertex fragment)) (error "invalid shader stage" stage)) (parameterize ((unique-identifier-counter 0) (unique-variable-type-counter 0)) (let ((source* `(top-level ,(append inputs outputs uniforms) ,source))) (define-values (expanded global-map) (expand source* stage (top-level-env))) (let* ((propagated (propagate-constants expanded (empty-env))) (hoisted (hoist-functions* propagated)) (inferred (infer-types hoisted stage)) (resolved (resolve-overloads inferred stage))) (values resolved global-map (unique-identifier-counter)))))) (define (specs->globals specs) (map (match-lambda ((qualifier type-desc name) (make-seagull-global qualifier type-desc name))) specs)) ;; Using syntax-case allows us to compile shaders to their fully typed ;; intermediate form at compile time, leaving only GLSL emission for ;; runtime. (define-syntax define-shader-stage (lambda (x) (syntax-case x () ((_ name stage ((qualifier type var) ...) 
source) (let* ((globals (group-by-qualifier (syntax->datum #'((qualifier type var) ...)))) (inputs (assq-ref globals 'inputs)) (outputs (assq-ref globals 'outputs)) (uniforms (assq-ref globals 'uniforms))) (define-values (compiled global-map max-id) (compile-seagull #:stage (syntax->datum #'stage) #:source (syntax->datum #'source) #:inputs inputs #:outputs outputs #:uniforms uniforms)) (with-syntax ((inputs (datum->syntax x inputs)) (outputs (datum->syntax x outputs)) (uniforms (datum->syntax x uniforms)) (compiled (datum->syntax x compiled)) (global-map (datum->syntax x global-map)) (max-id (datum->syntax x max-id))) #'(define name (make-seagull-module #:stage 'stage #:inputs (specs->globals 'inputs) #:outputs (specs->globals 'outputs) #:uniforms (specs->globals 'uniforms) #:source 'source #:compiled 'compiled #:global-map 'global-map #:max-id max-id)))))))) (define-syntax-rule (define-vertex-shader name specs source) (define-shader-stage name vertex specs source)) (define-syntax-rule (define-fragment-shader name specs source) (define-shader-stage name fragment specs source)) (define (vertex-outputs-match-fragment-inputs? vertex fragment) (let ((fragment-inputs (seagull-module-inputs fragment))) (every (lambda (o1) (any (lambda (o2) (and (eq? (seagull-global-name o1) (seagull-global-name o2)) (equal? (seagull-global-type-descriptor o1) (seagull-global-type-descriptor o2)))) fragment-inputs)) (seagull-module-outputs vertex)))) (define (uniforms-compatible? vertex fragment) (let ((fragment-uniforms (seagull-module-uniforms fragment))) (every (lambda (u1) (every (lambda (u2) (if (eq? (seagull-global-name u1) (seagull-global-name u2)) (equal? (seagull-global-type-descriptor u1) (seagull-global-type-descriptor u2)) #t)) fragment-uniforms)) (seagull-module-outputs vertex)))) (define (rewrite-variables exp subs) (match exp ((? symbol?) (or (assq-ref subs exp) exp)) (() '()) ((exp* . 
rest) (cons (rewrite-variables exp* subs) (rewrite-variables rest subs))) (_ exp))) (define (link-vertex-outputs-with-fragment-inputs vertex fragment) (define (map-globals specs global-map) (map (lambda (global) (let ((name (seagull-global-name global))) (cons name (assq-ref global-map name)))) specs)) (define (alpha-rename name-map) (map (match-lambda ((original-name . alpha-name) (cons alpha-name (unique-identifier)))) name-map)) (define (remap specs global-map alpha-map) (map (lambda (global) (let ((name (seagull-global-name global))) (cons (assq-ref alpha-map (assq-ref global-map name)) name))) specs)) (let* ((vertex-global-map (seagull-module-global-map vertex)) ;; Create a Scheme name -> alpha-converted GLSL name mapping ;; for vertex outputs. (vertex-output-map (map-globals (seagull-module-outputs vertex) vertex-global-map)) ;; Create a Scheme name -> alpha-converted GLSL name mapping ;; for vertex uniforms. (vertex-uniform-map (map-globals (seagull-module-uniforms vertex) vertex-global-map)) ;; Give new GLSL names to the vertex outputs and uniforms ;; that are unique to both the vertex and fragment shaders. ;; The vertex output names are changed so that the fragment ;; input names can be changed to match. The vertex uniform ;; names are changed so that the names do not clash with ;; fragment globals. (vertex-output-alpha-map (alpha-rename vertex-output-map)) (vertex-uniform-alpha-map (alpha-rename vertex-uniform-map)) (fragment-global-map (seagull-module-global-map fragment)) ;; Create a Scheme name -> alpha-converted GLSL name mapping ;; for fragment inputs. (fragment-input-map (map-globals (seagull-module-inputs fragment) fragment-global-map)) ;; Create a Scheme name -> alpha-converted GLSL name mapping ;; for fragment uniforms. 
(fragment-uniform-map (map-globals (seagull-module-uniforms fragment) fragment-global-map)) ;; Give new names to the fragment uniforms so that the names ;; do not clash with vertex globals and also that any ;; uniforms in the vertex shader have the *same* name in the ;; fragment shader. (fragment-uniform-alpha-map (map (match-lambda ((original-name . alpha-name) (cons alpha-name (or (assq-ref vertex-uniform-alpha-map (assq-ref vertex-uniform-map original-name)) (unique-identifier))))) fragment-uniform-map)) ;; This one is a little messy but what's happening is that ;; the GLSL name for each fragment output is mapped to the ;; respective renamed input. Vertex shader output names must ;; match fragment shader input names. (fragment-input-alpha-map (append (map (lambda (input) (let ((name (seagull-global-name input))) (cons (assq-ref fragment-global-map name) (assq-ref vertex-output-alpha-map (assq-ref vertex-global-map name))))) (seagull-module-inputs fragment))))) ;; Rewrite the intermediate compiled forms of both shader stages ;; to replace global variable names as needed. (values (rewrite-variables (seagull-module-compiled vertex) (append vertex-uniform-alpha-map vertex-output-alpha-map)) (rewrite-variables (seagull-module-compiled fragment) (append fragment-uniform-alpha-map fragment-input-alpha-map)) ;; Generate a list of alpha-converted GLSL name -> Scheme ;; name mappings. This will be given to the OpenGL shader ;; constructor to map the human readable uniform names to ;; the names they've been given by the compiler. 
(append (remap (seagull-module-uniforms vertex) vertex-global-map vertex-uniform-alpha-map) (remap (seagull-module-uniforms fragment) fragment-global-map fragment-uniform-alpha-map))))) (define (seagull-module-uniform-map module) (let ((global-map (seagull-module-global-map module))) (map (match-lambda ((_ _ name) (cons (assq-ref global-map name) name))) (seagull-module-uniforms module)))) (define* (link-seagull-modules vertex fragment #:key (version '330)) (unless (seagull-module-vertex? vertex) (error "not a vertex shader" vertex)) (unless (seagull-module-fragment? fragment) (error "not a fragment shader" fragment)) (parameterize ((unique-identifier-counter (max (seagull-module-max-id vertex) (seagull-module-max-id fragment)))) (unless (vertex-outputs-match-fragment-inputs? vertex fragment) (error "vertex outputs do not match fragment inputs")) (unless (uniforms-compatible? vertex fragment) (error "vertex uniforms clash with fragment uniforms")) (define-values (vertex* fragment* uniform-map) (link-vertex-outputs-with-fragment-inputs vertex fragment)) (define vertex-glsl (call-with-output-string (lambda (port) (emit-glsl vertex* 'fragment version port)))) (define fragment-glsl (call-with-output-string (lambda (port) (emit-glsl fragment* 'fragment version port)))) (display vertex-glsl) (newline) (display fragment-glsl) (newline) (values vertex-glsl fragment-glsl uniform-map))) (define (compile-shader vertex fragment) (let-values (((glsl:vertex glsl:fragment uniform-map) (link-seagull-modules vertex fragment))) (call-with-input-string glsl:vertex (lambda (vertex-port) (call-with-input-string glsl:fragment (lambda (fragment-port) (make-shader vertex-port fragment-port #:uniform-map uniform-map)))))))