summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Thompson <dthompson2@worcester.edu>2023-02-03 22:06:13 -0500
committerDavid Thompson <dthompson2@worcester.edu>2023-06-08 08:14:41 -0400
commitcbc4c95a0820934600a26329ee58cc5685060dcb (patch)
tree1514113686cf75acd132d352ab69671f9c2a0b2b
parentc972be910fdfcc50084319562e30a468ab55a211 (diff)
Qualified types that mostly work.
-rw-r--r--chickadee/graphics/seagull.scm1325
1 files changed, 693 insertions, 632 deletions
diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm
index 078e4aa..1ae82bb 100644
--- a/chickadee/graphics/seagull.scm
+++ b/chickadee/graphics/seagull.scm
@@ -18,18 +18,25 @@
;;; Commentary:
;;
;; The Seagull shading language is a purely functional, statically
-;; typed, Scheme-like language that can be compiled to GLSL code.
+;; typed, Scheme-like language that can be compiled to GLSL code. The
+;; reality of how GPUs work imposes some significant language
+;; restrictions, but they are restrictions anyone who writes shader
+;; code is already used to.
;;
-;; Notable features and restrictions:
+;; Features:
;; - Purely functional
;; - Vertex and fragment shader output
;; - Targets multiple GLSL versions
;; - Type inference
;; - Lexical scoping
-;; - First-order functions
-;; - Nested functions (but no closures!)
+;; - Nested functions
;; - Multiple return values
;;
+;; Limitations:
+;; - No first-class functions
+;; - No closures
+;; - No recursion
+;;
;;; Code:
(define-module (chickadee graphics seagull)
#:use-module (ice-9 exceptions)
@@ -47,6 +54,8 @@
;; TODO:
;; - Loops
;; - Shader stage linking
+;; - Transform for-all functions into overloads
+;; - User functions that use overloaded functions need to be overloaded themselves
;;;
@@ -763,14 +772,14 @@
;;;
-;;; Typed expressions
+;;; Type inference
;;;
-;; Typed expressions combine a Seagull expression with the data types
-;; that expression returns. This section includes all the kinds of
-;; types an expression may have and various transformation procedures
-;; to manipulate them. This will be used below to convert the mostly
-;; untyped Seagull program into a fully typed one.
+;; Walk the expression tree of a type annotated program and solve for
+;; all of the type variables using a variant of the Hindley-Milner
+;; type inference algorithm. GLSL is a statically typed language, but
+;; thanks to type inference the user doesn't have to specify any types
+;; expect for shader inputs, outputs, and uniforms.
;; Primitive types:
(define (primitive-type name)
@@ -825,6 +834,9 @@
(define (fresh-type-variable)
(type-variable (unique-type-variable-name)))
+(define (fresh-type-variables-for-list lst)
+ (map (lambda (_x) (fresh-type-variable)) lst))
+
(define (type-variable? obj)
(match obj
(('tvar _) #t)
@@ -847,37 +859,43 @@
(match type
(('-> _ returns) returns)))
+;; Qualified types:
+;; (define (qualified-type type predicate)
+;; `(qualified ,type ,predicate))
+
+;; (define (qualified-type? type)
+;; (match type
+;; (('qualified _ _) #t)
+;; (_ #f)))
+
+;; (define (qualified-type-ref type)
+;; (match type
+;; (('qualified type* _) type*)))
+
+;; (define (qualified-type-predicate type)
+;; (match type
+;; (('qualified _ predicate) predicate)))
+
;; For all types:
-(define (for-all-type quantifiers type)
- `(for-all ,quantifiers ,type))
+(define (for-all-type quantifiers type predicate)
+ `(for-all ,quantifiers ,type ,predicate))
(define (for-all-type? obj)
(match obj
- (('for-all _ _) #t)
+ (('for-all _ _ _) #t)
(_ #f)))
(define (for-all-type-quantifiers type)
(match type
- (('for-all q _) q)))
+ (('for-all q _ _) q)))
(define (for-all-type-ref type)
(match type
- (('for-all _ t) t)))
-
-;; Overload types:
-(define (overload-type . types)
- (match types
- (((? function-type?) ...)
- `(overload ,types))))
-
-(define (overload-type? obj)
- (match obj
- (('overload _) #t)
- (_ #f)))
+ (('for-all _ t _) t)))
-(define (overload-type-ref type)
+(define (for-all-type-predicate type)
(match type
- (('overload types) types)))
+ (('for-all _ _ p) p)))
;; Outputs type:
(define outputs-type '(outputs))
@@ -891,7 +909,8 @@
(or (primitive-type? obj)
(type-variable? obj)
(function-type? obj)
- (overload-type? obj)))
+ ;; (qualified-type? obj)
+ (outputs-type? obj)))
(define (apply-substitution-to-type type from to)
(cond
@@ -908,11 +927,13 @@
(map (lambda (return-type)
(apply-substitution-to-type return-type from to))
(function-type-returns type))))
- ((overload-type? type)
- (apply overload-type
- (map (lambda (type)
- (apply-substitution-to-type type from to))
- (overload-type-ref type))))
+ ;; ((qualified-type? type)
+ ;; (qualified-type (apply-substitution-to-type
+ ;; (qualified-type-ref type) from to)
+ ;; (apply-substitution-to-predicate
+ ;; (qualified-type-predicate type) from to)))
+ ((for-all-type? type)
+ type)
(else (error "invalid type" type))))
(define (apply-substitutions-to-type type subs)
@@ -926,6 +947,53 @@
(apply-substitutions-to-type type subs))
types))
+(define (apply-substitution-to-env env from to)
+ (env-fold (lambda (name types env*)
+ (extend-env name
+ (map (lambda (type)
+ (apply-substitution-to-type type from to))
+ types)
+ env*))
+ (empty-env)
+ env))
+
+(define (apply-substitutions-to-env env subs)
+ (env-fold (lambda (from to env*)
+ (apply-substitution-to-env env* from to))
+ env
+ subs))
+
+(define (apply-substitutions-to-texp t subs)
+ (texp (apply-substitutions-to-types (texp-types t) subs)
+ (texp-exp t)))
+
+(define (apply-substitution-to-predicate pred from to)
+ (match pred
+ (#t #t)
+ (('= a b)
+ `(= ,(apply-substitution-to-type a from to)
+ ,(apply-substitution-to-type b from to)))
+ (('and preds ...)
+ `(and ,@(map (lambda (pred)
+ (apply-substitution-to-predicate pred from to))
+ preds)))
+ (('or preds ...)
+ `(or ,@(map (lambda (pred)
+ (apply-substitution-to-predicate pred from to))
+ preds)))
+ (('when test consequent)
+ `(when ,(apply-substitution-to-predicate test from to)
+ ,(apply-substitution-to-predicate consequent from to)))
+ (('substitute a b)
+ `(substitute ,(apply-substitution-to-type a from to)
+ ,(apply-substitution-to-type b from to)))))
+
+(define (apply-substitutions-to-predicate pred subs)
+ (env-fold (lambda (from to pred*)
+ (apply-substitution-to-predicate pred* from to))
+ pred
+ subs))
+
;; Typed expressions:
(define (texp types exp)
`(t ,types ,exp))
@@ -948,184 +1016,217 @@
((type) type)
(_ (error "expected only 1 type" texp))))
-
-;;;
-;;; Type annotation
-;;;
+(define &seagull-type-error
+ (make-exception-type '&seagull-type-error &error '()))
+
+(define make-seagull-type-error
+ (record-constructor &seagull-type-error))
+
+(define (seagull-type-error msg args origin)
+ (raise-exception
+ (make-exception
+ (make-seagull-type-error)
+ (make-exception-with-origin origin)
+ (make-exception-with-message
+ (format #f "seagull type error: ~a" msg))
+ (make-exception-with-irritants args))))
+
+(define (occurs? a b)
+ (cond
+ ((and (type-variable? a) (type-variable? b))
+ (eq? a b))
+ ((and (type-variable? a) (function-type? b))
+ (or (occurs? a (function-type-parameters b))
+ (occurs? a (function-type-returns b))))
+ ((and (type? a) (list? b))
+ (any (lambda (b*) (occurs? a b*)) b))
+ (else #f)))
+
+(define (compose-substitutions a b)
+ (define b*
+ (map (match-lambda
+ ((from . to)
+ (cons from (apply-substitutions-to-type to a))))
+ b))
+ (define a*
+ (filter-map (match-lambda
+ ((from . to)
+ (if (assq-ref b* from)
+ #f
+ (cons from to))))
+ a))
+ (append a* b*))
-;; Convert untyped Seagull expressions into typed expressions with
-;; type variables representing all unknown types. This annotated
-;; version of a Seagull program can then be passed to the type
-;; inference algorithm to solve for all of the variables.
-
-(define add/sub-type
- (list (overload-type
- (function-type (list int-type int-type)
- (list int-type))
- (function-type (list float-type float-type)
- (list float-type))
- (function-type (list vec2-type vec2-type)
- (list vec2-type))
- (function-type (list vec3-type vec3-type)
- (list vec3-type))
- (function-type (list vec4-type vec4-type)
- (list vec4-type))
- (function-type (list mat3-type mat3-type)
- (list mat3-type))
- (function-type (list mat4-type mat4-type)
- (list mat4-type)))))
-
-(define mul-type
- (list (overload-type
- (function-type (list int-type int-type)
- (list int-type))
- (function-type (list float-type float-type)
- (list float-type))
- (function-type (list vec2-type vec2-type)
- (list vec2-type))
- (function-type (list vec2-type float-type)
- (list vec2-type))
- (function-type (list float-type vec2-type)
- (list vec2-type))
- (function-type (list vec3-type vec3-type)
- (list vec3-type))
- (function-type (list vec3-type float-type)
- (list vec3-type))
- (function-type (list float-type vec3-type)
- (list vec3-type))
- (function-type (list vec4-type vec4-type)
- (list vec4-type))
- (function-type (list vec4-type float-type)
- (list vec4-type))
- (function-type (list float-type vec4-type)
- (list vec4-type))
- (function-type (list mat3-type mat3-type)
- (list mat3-type))
- (function-type (list mat3-type vec3-type)
- (list mat3-type))
- (function-type (list vec3-type mat3-type)
- (list mat3-type))
- (function-type (list mat4-type mat4-type)
- (list mat4-type))
- (function-type (list mat4-type vec4-type)
- (list vec4-type))
- (function-type (list vec4-type mat4-type)
- (list vec4-type)))))
-
-(define div-type
- (list (overload-type
- (function-type (list int-type int-type)
- (list int-type))
- (function-type (list float-type float-type)
- (list float-type))
- (function-type (list vec2-type vec2-type)
- (list vec2-type))
- (function-type (list vec2-type float-type)
- (list vec2-type))
- (function-type (list vec3-type vec3-type)
- (list vec3-type))
- (function-type (list vec3-type float-type)
- (list vec3-type))
- (function-type (list vec4-type vec4-type)
- (list vec4-type))
- (function-type (list vec4-type float-type)
- (list vec4-type))
- (function-type (list mat3-type float-type)
- (list mat3-type))
- (function-type (list mat4-type float-type)
- (list mat4-type)))))
-
-(define comparison-type
- (list (overload-type
- (function-type (list int-type int-type)
- (list bool-type))
- (function-type (list float-type float-type)
- (list bool-type)))))
-
-(define make-vec2-type
- (list (function-type (list float-type float-type)
- (list vec2-type))))
-
-(define make-vec3-type
- (list (overload-type
- (function-type (list float-type float-type float-type)
- (list vec3-type))
- (function-type (list vec2-type float-type)
- (list vec3-type))
- (function-type (list float-type vec2-type)
- (list vec3-type)))))
-
-(define make-vec4-type
- (list (overload-type
- (function-type (list float-type float-type float-type float-type)
- (list vec4-type))
- (function-type (list vec2-type float-type float-type)
- (list vec4-type))
- (function-type (list float-type vec2-type float-type)
- (list vec4-type))
- (function-type (list float-type float-type vec2-type)
- (list vec4-type))
- (function-type (list vec3-type float-type)
- (list vec4-type))
- (function-type (list float-type vec3-type)
- (list vec4-type)))))
-
-(define abs-type
- (list (overload-type
- (function-type (list int-type) (list int-type))
- (function-type (list float-type) (list float-type)))))
-
-(define sqrt-type
- (list (overload-type
- (function-type (list int-type) (list int-type))
- (function-type (list float-type) (list float-type)))))
-
-(define min/max-type
- (list (overload-type
- (function-type (list int-type int-type) (list int-type))
- (function-type (list float-type float-type) (list float-type)))))
-
-(define trigonometry-type
- (list (function-type (list float-type) (list float-type))))
-
-(define clamp/mix-type
- (list (overload-type
- (function-type (list int-type int-type int-type)
- (list int-type))
- (function-type (list float-type float-type float-type)
- (list float-type)))))
+;; (define add/sub-type
+;; (list (overload-type
+;; (function-type (list int-type int-type)
+;; (list int-type))
+;; (function-type (list float-type float-type)
+;; (list float-type))
+;; (function-type (list vec2-type vec2-type)
+;; (list vec2-type))
+;; (function-type (list vec3-type vec3-type)
+;; (list vec3-type))
+;; (function-type (list vec4-type vec4-type)
+;; (list vec4-type))
+;; (function-type (list mat3-type mat3-type)
+;; (list mat3-type))
+;; (function-type (list mat4-type mat4-type)
+;; (list mat4-type)))))
+
+;; (define mul-type
+;; (list (overload-type
+;; (function-type (list int-type int-type)
+;; (list int-type))
+;; (function-type (list float-type float-type)
+;; (list float-type))
+;; (function-type (list vec2-type vec2-type)
+;; (list vec2-type))
+;; (function-type (list vec2-type float-type)
+;; (list vec2-type))
+;; (function-type (list float-type vec2-type)
+;; (list vec2-type))
+;; (function-type (list vec3-type vec3-type)
+;; (list vec3-type))
+;; (function-type (list vec3-type float-type)
+;; (list vec3-type))
+;; (function-type (list float-type vec3-type)
+;; (list vec3-type))
+;; (function-type (list vec4-type vec4-type)
+;; (list vec4-type))
+;; (function-type (list vec4-type float-type)
+;; (list vec4-type))
+;; (function-type (list float-type vec4-type)
+;; (list vec4-type))
+;; (function-type (list mat3-type mat3-type)
+;; (list mat3-type))
+;; (function-type (list mat3-type vec3-type)
+;; (list mat3-type))
+;; (function-type (list vec3-type mat3-type)
+;; (list mat3-type))
+;; (function-type (list mat4-type mat4-type)
+;; (list mat4-type))
+;; (function-type (list mat4-type vec4-type)
+;; (list vec4-type))
+;; (function-type (list vec4-type mat4-type)
+;; (list vec4-type)))))
+
+;; (define div-type
+;; (list (overload-type
+;; (function-type (list int-type int-type)
+;; (list int-type))
+;; (function-type (list float-type float-type)
+;; (list float-type))
+;; (function-type (list vec2-type vec2-type)
+;; (list vec2-type))
+;; (function-type (list vec2-type float-type)
+;; (list vec2-type))
+;; (function-type (list vec3-type vec3-type)
+;; (list vec3-type))
+;; (function-type (list vec3-type float-type)
+;; (list vec3-type))
+;; (function-type (list vec4-type vec4-type)
+;; (list vec4-type))
+;; (function-type (list vec4-type float-type)
+;; (list vec4-type))
+;; (function-type (list mat3-type float-type)
+;; (list mat3-type))
+;; (function-type (list mat4-type float-type)
+;; (list mat4-type)))))
+
+;; (define comparison-type
+;; (list (overload-type
+;; (function-type (list int-type int-type)
+;; (list bool-type))
+;; (function-type (list float-type float-type)
+;; (list bool-type)))))
+
+;; (define make-vec2-type
+;; (list (function-type (list float-type float-type)
+;; (list vec2-type))))
+
+;; (define make-vec3-type
+;; (list (overload-type
+;; (function-type (list float-type float-type float-type)
+;; (list vec3-type))
+;; (function-type (list vec2-type float-type)
+;; (list vec3-type))
+;; (function-type (list float-type vec2-type)
+;; (list vec3-type)))))
+
+;; (define make-vec4-type
+;; (list (overload-type
+;; (function-type (list float-type float-type float-type float-type)
+;; (list vec4-type))
+;; (function-type (list vec2-type float-type float-type)
+;; (list vec4-type))
+;; (function-type (list float-type vec2-type float-type)
+;; (list vec4-type))
+;; (function-type (list float-type float-type vec2-type)
+;; (list vec4-type))
+;; (function-type (list vec3-type float-type)
+;; (list vec4-type))
+;; (function-type (list float-type vec3-type)
+;; (list vec4-type)))))
+
+;; (define abs-type
+;; (list (overload-type
+;; (function-type (list int-type) (list int-type))
+;; (function-type (list float-type) (list float-type)))))
+
+;; (define sqrt-type
+;; (list (overload-type
+;; (function-type (list int-type) (list int-type))
+;; (function-type (list float-type) (list float-type)))))
+
+;; (define min/max-type
+;; (list (overload-type
+;; (function-type (list int-type int-type) (list int-type))
+;; (function-type (list float-type float-type) (list float-type)))))
+
+;; (define trigonometry-type
+;; (list (function-type (list float-type) (list float-type))))
+
+;; (define clamp/mix-type
+;; (list (overload-type
+;; (function-type (list int-type int-type int-type)
+;; (list int-type))
+;; (function-type (list float-type float-type float-type)
+;; (list float-type)))))
(define (top-level-type-env stage)
- `((+ . ,add/sub-type)
- (- . ,add/sub-type)
- (* . ,mul-type)
- (/ . ,div-type)
- (= . ,comparison-type)
- (< . ,comparison-type)
- (<= . ,comparison-type)
- (> . ,comparison-type)
- (>= . ,comparison-type)
- (vec2 . ,make-vec2-type)
- (vec3 . ,make-vec3-type)
- (vec4 . ,make-vec4-type)
- (not ,(function-type (list bool-type) (list bool-type)))
- (int->float ,(function-type (list int-type) (list float-type)))
- (float->int ,(function-type (list float-type) (list int-type)))
- (abs . ,abs-type)
- (sqrt . ,sqrt-type)
- (min . ,min/max-type)
- (max . ,min/max-type)
- (sin . ,trigonometry-type)
- (cos . ,trigonometry-type)
- (tan . ,trigonometry-type)
- (clamp . ,clamp/mix-type)
- (mix . ,clamp/mix-type)
- ,@(case stage
- ((vertex)
- `((gl-position ,vec4-type)))
- ((fragment)
- `((texture-2d ,(function-type (list sampler-2d-type vec2-type)
- (list vec4-type))))))))
+ '()
+ ;; `((+ . ,add/sub-type)
+ ;; (- . ,add/sub-type)
+ ;; (* . ,mul-type)
+ ;; (/ . ,div-type)
+ ;; (= . ,comparison-type)
+ ;; (< . ,comparison-type)
+ ;; (<= . ,comparison-type)
+ ;; (> . ,comparison-type)
+ ;; (>= . ,comparison-type)
+ ;; (vec2 . ,make-vec2-type)
+ ;; (vec3 . ,make-vec3-type)
+ ;; (vec4 . ,make-vec4-type)
+ ;; (not ,(function-type (list bool-type) (list bool-type)))
+ ;; (int->float ,(function-type (list int-type) (list float-type)))
+ ;; (float->int ,(function-type (list float-type) (list int-type)))
+ ;; (abs . ,abs-type)
+ ;; (sqrt . ,sqrt-type)
+ ;; (min . ,min/max-type)
+ ;; (max . ,min/max-type)
+ ;; (sin . ,trigonometry-type)
+ ;; (cos . ,trigonometry-type)
+ ;; (tan . ,trigonometry-type)
+ ;; (clamp . ,clamp/mix-type)
+ ;; (mix . ,clamp/mix-type)
+ ;; ,@(case stage
+ ;; ((vertex)
+ ;; `((gl-position ,vec4-type)))
+ ;; ((fragment)
+ ;; `((texture-2d ,(function-type (list sampler-2d-type vec2-type)
+ ;; (list vec4-type)))))))
+ )
(define (lookup-type name env)
(let ((type (lookup name env)))
@@ -1175,13 +1276,40 @@
'()
env)))
-(define (generalize type env)
+(define (free-variables-in-predicate pred)
+ (match pred
+ (#t '())
+ (('= a b)
+ (append (free-variables-in-type a)
+ (free-variables-in-type b)))
+ (('and preds ...)
+ (append-map (lambda (pred)
+ (free-variables-in-predicate pred))
+ preds))
+ (('or preds ...)
+ (append-map (lambda (pred)
+ (free-variables-in-predicate pred))
+ preds))
+ (('when test consequent)
+ (append (free-variables-in-predicate test)
+ (free-variables-in-predicate consequent)))
+ (('substitute a b)
+ (append (free-variables-in-type a)
+ (free-variables-in-type b)))))
+
+;; Quantified variables:
+;; - Unused parameters
+;; - Parameters that appear free in the return type
+;; - Parameters that are used in overloaded primcalls
+(define (generalize type pred env)
(if (function-type? type)
- (match (difference (free-variables-in-type type)
+ (match (difference (delete-duplicates
+ (append (free-variables-in-type type)
+ (free-variables-in-predicate pred)))
(free-variables-in-env env))
(() type)
((quantifiers ...)
- (for-all-type quantifiers type)))
+ (for-all-type quantifiers type pred)))
type))
(define (instantiate for-all)
@@ -1190,451 +1318,384 @@
(extend-env var (fresh-type-variable) env))
(empty-env)
(for-all-type-quantifiers for-all)))
- (apply-substitutions-to-type (for-all-type-ref for-all) subs))
-
-(define (fresh-type-variables-for-list lst)
- (map (lambda (_x) (fresh-type-variable)) lst))
-
-(define (annotate:list exps env)
- (map (lambda (exp) (annotate-exp exp env)) exps))
-
-(define (annotate:if predicate consequent alternate env)
- (define consequent* (annotate-exp consequent env))
- (texp (texp-types consequent*)
- `(if ,(annotate-exp predicate env)
- ,consequent*
- ,(annotate-exp alternate env))))
-
-(define (annotate:let names exps body env)
- (define exps* (annotate:list exps env))
- (define exp-types (map texp-types exps*))
- (define env* (fold extend-env env names exp-types))
- (define body* (annotate-exp body env*))
- (texp (texp-types body*)
- `(let ,(map list names exps*) ,body*)))
-
-(define (annotate:lambda params body env)
- ;; Each function parameter gets a fresh type variable.
- (define param-types (fresh-type-variables-for-list params))
- ;; The type environment is extended with the function parameters.
- (define env*
- (fold (lambda (param type env*)
- (extend-env param (list type) env*))
- env params param-types))
- ;; The body is annotated in the new environment.
- (define body* (annotate-exp body env*))
- (texp (list (function-type param-types (texp-types body*)))
- `(lambda ,params ,body*)))
-
-(define (annotate:values exps env)
- (define exps* (annotate:list exps env))
- (texp (map single-type exps*)
- `(values ,@exps*)))
-
-(define (annotate:primitive-call operator args env)
- ;; The type signature of primitive functions can be looked up
- ;; directly in the environment.
- (define operator-types (lookup-type operator env))
- (define operator* (texp operator-types operator))
- (define args* (annotate:list args env))
- (texp (fresh-type-variables-for-list (list operator*))
- `(primcall ,operator* ,@args*)))
-
-(define (annotate:call operator args env)
- (define operator* (annotate-exp operator env))
- (define args* (annotate:list args env))
- (texp (fresh-type-variables-for-list
- (function-type-returns
- (single-type operator*)))
- `(call ,operator* ,@args*)))
-
-(define* (annotate:top-level bindings body env #:optional (result '()))
- (match bindings
- (()
- (let ((body* (annotate-exp body env)))
- (texp (texp-types body*) `(top-level ,(reverse result) ,body*))))
- ((('function name exp) . rest)
- (define exp*
- (let ((x (annotate-exp exp env)))
- ;; Function types must be generalized so that functions like
- ;; (lambda (x) x) can truly be applied to any type.
- (texp (list (generalize (single-type x) env))
- (texp-exp x))))
- (define env* (extend-env name (texp-types exp*) env))
- (define result* (cons `(function ,name ,exp*) result))
- (annotate:top-level rest body env* result*))
- ((((? top-level-qualifier? qualifier) type-name name) . rest)
- (define env* (extend-env name (list (type-name->type type-name)) env))
- (define result* (cons (list qualifier type-name name) result))
- (annotate:top-level rest body env* result*))))
-
-(define (annotate:outputs names exps env)
- (texp (list outputs-type)
- `(outputs
- ,@(map (lambda (name exp)
- (list (texp (lookup name env) name)
- (annotate-exp exp env)))
- names exps))))
-
-(define (annotate-exp exp env)
- (match exp
- ((? exact-integer?)
- (texp (list int-type) exp))
- ((? float?)
- (texp (list float-type) exp))
- ((? boolean?)
- (texp (list bool-type) exp))
- (('var var _)
- (texp (lookup-type var env) exp))
- (('if predicate consequent alternate)
- (annotate:if predicate consequent alternate env))
- (('let ((names exps) ...) body)
- (annotate:let names exps body env))
- (('lambda (params ...) body)
- (annotate:lambda params body env))
- (('values exps ...)
- (annotate:values exps env))
- (('primcall operator args ...)
- (annotate:primitive-call operator args env))
- (('call operator args ...)
- (annotate:call operator args env))
- (('outputs (names exps) ...)
- (annotate:outputs names exps env))
- (('top-level bindings body)
- (annotate:top-level bindings body env))))
-
-(define (annotate-exp* exp stage)
- (annotate-exp exp (top-level-type-env stage)))
-
-
-;;;
-;;; Type inference
-;;;
-
-;; Walk the expression tree of a type annotated program and solve for
-;; all of the type variables using a variant of the Hindley-Milner
-;; type inference algorithm. GLSL is a statically typed language, but
-;; thanks to type inference the user doesn't have to specify any types
-;; expect for shader inputs, outputs, and uniforms.
-
-(define &seagull-type-error
- (make-exception-type '&seagull-type-error &error '()))
-
-(define make-seagull-type-error
- (record-constructor &seagull-type-error))
-
-(define (seagull-type-error msg args origin)
- (raise-exception
- (make-exception
- (make-seagull-type-error)
- (make-exception-with-origin origin)
- (make-exception-with-message
- (format #f "seagull type error: ~a" msg))
- (make-exception-with-irritants args))))
-
-(define (occurs? a b)
+ (values
+ (apply-substitutions-to-type (for-all-type-ref for-all) subs)
+ (apply-substitutions-to-predicate (for-all-type-predicate for-all) subs)))
+
+(define (maybe-instantiate types)
+ (define types+preds
+ (map (lambda (type)
+ (if (for-all-type? type)
+ (call-with-values (lambda () (instantiate type)) list)
+ (list type #t)))
+ types))
+ (values (map first types+preds)
+ (reduce compose-predicates #t (map second types+preds))))
+
+;; (define (strip-qualifier type)
+;; (if (qualified-type? type)
+;; (qualified-type-ref type)
+;; type))
+
+;; (define (strip-qualifiers types)
+;; (map strip-qualifier types))
+
+;; (define (predicate-for-type type)
+;; (if (qualified-type? type)
+;; (qualified-type-predicate type)
+;; #t))
+
+(define (compose-predicates a b)
(cond
- ((and (type-variable? a) (type-variable? b))
- (eq? a b))
- ((and (type-variable? a) (function-type? b))
- (or (occurs? a (function-type-parameters b))
- (occurs? a (function-type-returns b))))
- ((and (type? a) (list? b))
- (any (lambda (b*) (occurs? a b*)) b))
- (else #f)))
-
-(define (compose-substitutions a b)
- (define b*
- (map (match-lambda
- ((from . to)
- (cons from (apply-substitutions-to-type to a))))
- b))
- (define a*
- (filter-map (match-lambda
- ((from . to)
- (if (assq-ref b* from)
- #f
- (cons from to))))
- a))
- (append a* b*))
-
-(define unify-prompt-tag (make-prompt-tag 'unify))
-
-(define (call-with-unify-rollback thunk handler)
- (call-with-prompt unify-prompt-tag
- thunk
- (lambda (k args)
- (apply handler args))))
-
-(define (unify:fail . args)
- (abort-to-prompt unify-prompt-tag args))
-
-(define (unify:primitives a b success)
+ ((and (eq? a #t) (eq? b #t))
+ #t)
+ ((eq? a #t)
+ b)
+ ((eq? b #t)
+ a)
+ (else
+ `(and ,a ,b))))
+
+(define (compose-predicates* preds)
+ (reduce (lambda (pred memo)
+ (compose-predicates pred memo))
+ #t
+ preds))
+
+(define (compose-predicates-for-types types)
+ (compose-predicates* (map predicate-for-type types)))
+
+;; Produces a simplified predicate and a new set of substitutions for
+;; predicates that have been satisfied and simplified to #t.
+(define (eval-predicate pred)
+ (match pred
+ (#t (values #t '()))
+ ((or ('= (? type-variable?) _)
+ ('= _ (? type-variable?)))
+ (values pred '()))
+ (('= a b)
+ (values (equal? a b) '()))
+ (('or preds ...)
+ (let loop ((preds preds))
+ (match preds
+ (() (values #f '()))
+ ((pred* . rest)
+ (define-values (new-pred subs)
+ (eval-predicate pred*))
+ (match new-pred
+ (#t (values #t subs))
+ (#f (eval-predicate `(or ,@rest)))
+ (_ (values `(or ,new-pred ,@rest) '())))))))
+ (('and preds ...)
+ (let loop ((preds preds))
+ (match preds
+ (() (values #t '()))
+ ((pred* . rest)
+ (define-values (new-pred subs)
+ (eval-predicate pred*))
+ (match new-pred
+ (#t
+ (let ()
+ (define-values (new-pred* subs*)
+ (eval-predicate `(and ,@rest)))
+ (values new-pred* (compose-substitutions subs subs*))))
+ (#f (values #f '()))
+ (_ (values `(and ,new-pred ,@rest) '())))))))
+ (('when test consequent)
+ (define-values (new-test subs)
+ (eval-predicate test))
+ (match new-test
+ (#t
+ (let ()
+ (define-values (new-consequent subs*)
+ (eval-predicate consequent))
+ (values new-consequent (compose-substitutions subs subs*))))
+ (#f (values #f '()))
+ (_ (values `(when ,new-test ,consequent) '()))))
+ (('substitute a b)
+ (values #t (list (cons a b))))))
+
+(define (eval-predicate* pred subs)
+ (define-values (new-pred pred-subs)
+ (eval-predicate
+ (apply-substitutions-to-predicate pred subs)))
+ ;; TODO: Get information about *why* the predicate failed.
+ (unless new-pred
+ (error "predicate failure"))
+ (values new-pred (compose-substitutions subs pred-subs)))
+
+(define (unify:primitives a b)
(if (equal? a b)
- (success '())
- (unify:fail "primitive type mismatch" a b)))
+ '()
+ (error "primitive type mismatch" a b)))
-(define (unify:variable a b success)
+(define (unify:variable a b)
(cond
((eq? a b)
- (success '()))
+ '())
((occurs? a b)
- (unify:fail "type contains reference to itself" a b))
+ (error "type contains reference to itself" a b))
(else
- (success (list (cons a b))))))
-
-(define (unify:functions a b success)
- (unify (function-type-parameters a)
- (function-type-parameters b)
- (lambda (sub0)
- (unify (apply-substitutions-to-types (function-type-returns a)
- sub0)
- (apply-substitutions-to-types (function-type-returns b)
- sub0)
- (lambda (sub1)
- (success (compose-substitutions sub0 sub1)))))))
-
-(define (unify:overload a b success)
- (define (try-unify functions)
- (match functions
- (()
- (unify:fail "no matching overload" a b))
- ((function . rest)
- (call-with-unify-rollback
- (lambda ()
- (unify function b success))
- (lambda _
- (try-unify rest))))))
- (try-unify (overload-type-ref a)))
-
-(define (unify:lists a rest-a b rest-b success)
- (unify a b
- (lambda (sub0)
- (unify (apply-substitutions-to-types rest-a sub0)
- (apply-substitutions-to-types rest-b sub0)
- (lambda (sub1)
- (success (compose-substitutions sub0 sub1)))))))
-
-(define (unify a b success)
+ (list (cons a b)))))
+
+(define (unify:functions a b)
+ (define param-subs
+ (unify (function-type-parameters a)
+ (function-type-parameters b)))
+ (define return-subs
+ (unify (apply-substitutions-to-types (function-type-returns a)
+ param-subs)
+ (apply-substitutions-to-types (function-type-returns b)
+ param-subs)))
+ (compose-substitutions param-subs return-subs))
+
+(define (unify:lists a rest-a b rest-b)
+ (define sub-first (unify a b))
+ (define sub-rest
+ (unify (apply-substitutions-to-types rest-a sub-first)
+ (apply-substitutions-to-types rest-b sub-first)))
+ (compose-substitutions sub-first sub-rest))
+
+(define (unify a b)
(match (list a b)
(((? primitive-type? a) (? primitive-type? b))
- (unify:primitives a b success))
+ (unify:primitives a b))
((or ((? type-variable? a) b)
(b (? type-variable? a)))
- (unify:variable a b success))
+ (unify:variable a b))
(((? function-type? a) (? function-type? b))
- (unify:functions a b success))
- ((a (? overload-type? b))
- (unify:overload b a success))
+ (unify:functions a b))
(((? outputs-type?) (? outputs-type?))
- (success '()))
+ '())
((() ())
- (success '()))
+ '())
(((a rest-a ...) (b rest-b ...))
- (unify:lists a rest-a b rest-b success))
+ (unify:lists a rest-a b rest-b))
(_
- (unify:fail "type mismatch" a b))))
-
-(define (infer:list exps subs success)
- (match exps
- (() (success subs))
- ((exp . rest)
- (infer exp
- subs
- (lambda (sub0)
- (infer:list rest
- sub0
- (lambda (sub1)
- (success
- (compose-substitutions sub0 sub1)))))))))
-
-(define (infer:if types predicate consequent alternate subs success)
- (infer predicate
- subs
- (lambda (sub0)
- (unify (single-type predicate)
- bool-type
- (lambda (sub1)
- (define sub2 (compose-substitutions sub0 sub1))
- (infer consequent
- sub2
- (lambda (sub3)
- (infer alternate
- sub3
- (lambda (sub4)
- (unify (apply-substitutions-to-types
- (texp-types consequent)
- sub4)
- (apply-substitutions-to-types
- (texp-types alternate)
- sub4)
- (lambda (sub5)
- (success
- (compose-substitutions
- sub4 sub5)))))))))))))
-
-(define (infer:lambda type body subs success)
- (define type*
- (if (for-all-type? type)
- (for-all-type-ref type)
- type))
- (infer body
- subs
- (lambda (sub0)
- (unify (apply-substitutions-to-types (texp-types body) sub0)
- (function-type-returns type*)
- (lambda (sub1)
- (success
- (compose-substitutions sub0 sub1)))))))
-
-(define (infer:values types exps subs success)
- (infer:list exps subs success))
-
-(define (infer:call types operator args subs success)
- (infer operator
- subs
- (lambda (sub0)
- (infer:list
- args
- sub0
- (lambda (sub1)
- ;; Check if function call has the proper number of
- ;; arguments.
- (let* ((k (length args))
- (l (length (function-type-parameters
- (single-type operator)))))
- (unless (= k l)
- (seagull-type-error
- (format #f "expected ~a arguments, got ~a" l k)
- '()
- infer:call)))
- (unify (apply-substitutions-to-type
- (function-type (map single-type args)
- types)
- sub1)
- (apply-substitutions-to-type
- (single-type operator)
- sub1)
- (lambda (sub2)
- (success (compose-substitutions sub1 sub2)))))))))
-
-(define (infer:primcall types operator args subs success)
- (infer:list
- args
- subs
- (lambda (sub0)
- (unify (apply-substitutions-to-type
- (function-type (map single-type args)
- types)
- sub0)
- (apply-substitutions-to-type
- (single-type operator)
- sub0)
- (lambda (sub1)
- (success
- (compose-substitutions sub0 sub1)))))))
-
-(define (infer:let types names exps body subs success)
- (infer:list
- exps
- subs
- (lambda (sub0)
- (infer body
- sub0
- (lambda (sub1)
- (unify (apply-substitutions-to-types types sub1)
- (apply-substitutions-to-types (texp-types body) sub1)
- (lambda (sub2)
- (success
- (compose-substitutions sub1 sub2)))))))))
-
-(define (infer:top-level types bindings body subs success)
- (define exps
- (filter-map (match-lambda
- (('function _ exp) exp)
- (_ #f))
- bindings))
- (infer:list
- exps
- subs
- (lambda (sub0)
- (infer body
- sub0
- (lambda (sub1)
- (unify (apply-substitutions-to-types types sub1)
- (apply-substitutions-to-types (texp-types body) sub1)
- (lambda (sub2)
- (success
- (compose-substitutions sub1 sub2)))))))))
-
-(define (infer:outputs names exps subs success)
- (define output-types (map single-type names))
- (infer:list
- exps
- subs
- (lambda (sub0)
- (unify output-types
- (apply-substitutions-to-types (map single-type exps) sub0)
- (lambda (sub1)
- (success
- (compose-substitutions sub0 sub1)))))))
-
-(define (infer exp subs success)
+ (error "type mismatch" a b))))
+
+(define (infer:immediate x)
+ (values (texp (list (cond
+ ((exact-integer? x)
+ int-type)
+ ((float? x)
+ float-type)
+ ((boolean? x)
+ bool-type)))
+ x)
+ '()
+ #t))
+
+(define (infer:variable name env)
+ (define-values (types pred)
+ (maybe-instantiate (lookup-type name env)))
+ (values (texp types name)
+ '()
+ pred))
+
+(define (infer:list exps env)
+ (let loop ((exps exps)
+ (texps '())
+ (subs '())
+ (pred #t))
+ (match exps
+ (()
+ (values (reverse texps) subs pred))
+ ((exp . rest)
+ (define-values (texp subs* pred*)
+ (infer-exp exp env))
+ (define-values (new-pred combined-subs)
+ (eval-predicate* (compose-predicates pred pred*)
+ (compose-substitutions subs subs*)))
+ (loop rest
+ (cons texp texps)
+ combined-subs
+ new-pred)))))
+
+(define (infer:if predicate consequent alternate env)
+ ;; Infer predicate types and unify it with the boolean type.
+ (define-values (predicate-texp predicate-subs predicate-pred)
+ (infer-exp predicate env))
+ (define predicate-unify-subs
+ (unify (texp-types predicate-texp) (list bool-type)))
+ ;; Combine the substitutions and apply them to the environment.
+ (define combined-subs-0
+ (compose-substitutions predicate-subs predicate-unify-subs))
+ (define env0
+ (apply-substitutions-to-env env combined-subs-0))
+ ;; Infer consequent and alternate types and unify them against each
+ ;; other. Each branch of an 'if' should have the same type.
+ (define-values (consequent-texp consequent-subs consequent-pred)
+ (infer-exp consequent env0))
+ (define combined-subs-1
+ (compose-substitutions combined-subs-0 consequent-subs))
+ (define env1
+ (apply-substitutions-to-env env0 consequent-subs))
+ (define-values (alternate-texp alternate-subs alternate-pred)
+ (infer-exp alternate env1))
+ (define combined-subs-2
+ (compose-substitutions combined-subs-1 alternate-subs))
+ ;; Eval combined predicate.
+ (define-values (pred combined-subs-3)
+ (eval-predicate* (compose-predicates predicate-pred
+ (compose-predicates consequent-pred
+ alternate-pred))
+ combined-subs-2))
+ ;; ;; Apply final set of substitutions to the types of both branches.
+ (define consequent-texp*
+ (apply-substitutions-to-texp consequent-texp combined-subs-3))
+ (define alternate-texp*
+ (apply-substitutions-to-texp alternate-texp combined-subs-3))
+ (values (texp (texp-types consequent-texp)
+ `(if ,predicate-texp ,consequent-texp ,alternate-texp))
+ combined-subs-3
+ pred))
+
+(define (infer:lambda params body env)
+ ;; Each function parameter gets a fresh type variable.
+ (define param-types (fresh-type-variables-for-list params))
+ ;; The type environment is extended with the function parameters.
+ (define env*
+ (fold (lambda (param type env*)
+ (extend-env param (list type) env*))
+ env params param-types))
+ (define-values (body* body-subs body-pred)
+ (infer-exp body env*))
+ (define-values (pred subs)
+ (eval-predicate* body-pred body-subs))
+ (values (texp (list (generalize
+ (function-type (apply-substitutions-to-types param-types
+ body-subs)
+ (texp-types body*))
+ pred env))
+ `(lambda ,params ,body*))
+ subs #t))
+
+(define (infer:primitive-call operator args env)
+ ;; The type signature of primitive functions can be looked up
+ ;; directly in the environment. Primitive functions may be
+ ;; overloaded and need to be instantiated with fresh type variables.
+ (define-values (types operator-pred)
+ (maybe-instantiate (lookup-type operator env)))
+ (define operator-type
+ (match types
+ ((type) type)))
+ ;; Infer the arguments.
+ (define-values (args* arg-subs arg-pred)
+ (infer:list args env))
+ ;; Generate fresh type variables to unify against the return types
+ ;; of the operator.
+ (define return-vars
+ (fresh-type-variables-for-list (function-type-returns operator-type)))
+ (define call-subs
+ (unify operator-type
+ (function-type (map single-type args*)
+ return-vars)))
+ ;; Apply substitutions to the predicate and then eval it, producing
+ ;; a simplified predicate and a set of substitutions.
+ (define-values (pred combined-subs)
+ (eval-predicate* (compose-predicates operator-pred arg-pred)
+ (compose-substitutions arg-subs call-subs)))
+ (values (texp (apply-substitutions-to-types return-vars combined-subs)
+ `(primcall ,operator
+ ,@(map (lambda (arg)
+ (apply-substitutions-to-texp arg
+ combined-subs))
+ args*)))
+ combined-subs
+ pred))
+
+(define (infer:call operator args env)
+ ;; The type signature of primitive functions can be looked up
+ ;; directly in the environment.
+ (define-values (operator* operator-subs operator-pred)
+ (infer-exp operator env))
+ (define env*
+ (apply-substitutions-to-env env operator-subs))
+ ;; Infer the arguments.
+ (define-values (args* arg-subs arg-pred)
+ (infer:list args env*))
+ (define combined-subs-0
+ (compose-substitutions operator-subs arg-subs))
+ ;; Generate fresh type variables to unify against the return types
+ ;; of the operator.
+ (define operator-type (single-type operator*))
+ (define return-vars
+ (fresh-type-variables-for-list
+ (function-type-returns operator-type)))
+ (define call-subs
+ (unify (apply-substitutions-to-type operator-type combined-subs-0)
+ (function-type (apply-substitutions-to-types (map single-type args*)
+ combined-subs-0)
+ return-vars)))
+ ;; Eval predicate.
+ (define-values (pred combined-subs)
+ (eval-predicate* (compose-predicates operator-pred
+ arg-pred)
+ (compose-substitutions combined-subs-0 call-subs)))
+ (values (texp (apply-substitutions-to-types return-vars combined-subs)
+ `(call ,(apply-substitutions-to-texp operator* combined-subs)
+ ,@(map (lambda (arg)
+ (apply-substitutions-to-texp arg
+ combined-subs))
+ args*)))
+ combined-subs
+ pred))
+
+(define (infer:let names exps body env)
+ (define-values (exps* exp-subs exp-pred)
+ (infer:list exps env))
+ (define exp-types (map texp-types exps*))
+ (define env*
+ (fold extend-env
+ (apply-substitutions-to-env env exp-subs)
+ names
+ exp-types))
+ (define-values (body* body-subs body-pred)
+ (infer-exp body env*))
+ (define-values (pred combined-subs)
+ (eval-predicate* (compose-predicates exp-pred body-pred)
+ (compose-substitutions exp-subs body-subs)))
+ (values (texp (texp-types body*)
+ `(let ,(map list names exps*) ,body*))
+ combined-subs
+ pred))
+
+;; Inference returns 3 values:
+;; - a typed expression
+;; - a list of substitutions
+;; - a type predicate
+(define (infer-exp exp env)
(match exp
- (('t types (or (? immediate?) ('var _ _)))
- (success subs))
- (('t types ('if predicate consequent alternate))
- (infer:if types predicate consequent alternate subs success))
- (('t types ('let ((names exps) ...) body))
- (infer:let types names exps body subs success))
- (('t (type) ('lambda (params ...) body))
- (infer:lambda type body subs success))
- (('t types ('values exps ...))
- (infer:values types exps subs success))
- (('t types ('primcall operator args ...))
- (infer:primcall types operator args subs success))
- (('t types ('call operator args ...))
- (infer:call types operator args subs success))
- (('t _ ('outputs (names exps) ...))
- (infer:outputs names exps subs success))
- (('t types ('top-level bindings body))
- (infer:top-level types bindings body subs success))
+ ((? immediate?)
+ (infer:immediate exp))
+ ((? symbol?)
+ (infer:variable exp env))
+ (('if predicate consequent alternate)
+ (infer:if predicate consequent alternate env))
+ (('let ((names exps) ...) body)
+ (infer:let names exps body env))
+ (('lambda (params ...) body)
+ (infer:lambda params body env))
+ ;; (('values exps ...)
+ ;; (infer:values exps env))
+ (('primcall operator args ...)
+ (infer:primitive-call operator args env))
+ (('call operator args ...)
+ (infer:call operator args env))
+ ;; (('outputs (names exps) ...)
+ ;; (infer:outputs names exps env))
+ ;; (('top-level bindings body)
+ ;; (infer:top-level bindings body env))
(_ (error "unknown form" exp))))
-(define (resolve:type-variable var env)
- (let ((type (assq-ref env var)))
- (if type
- (resolve type env)
- var)))
-
-;; TODO: If there are still quantified type variables, scan the entire
-;; program for calls to the function and build an intersection type.
-(define (resolve:for-all type env)
- (let ((quantifiers (filter type-variable?
- (resolve (for-all-type-quantifiers type) env)))
- (ref (resolve (for-all-type-ref type) env)))
- (if (null? quantifiers) ref (for-all-type quantifiers ref))))
-
-(define (resolve:list exps env)
- (map (lambda (exp) (resolve exp env)) exps))
-
-(define (resolve exp env)
- (match exp
- ((? type-variable?)
- (resolve:type-variable exp env))
- ((? for-all-type?)
- (resolve:for-all exp env))
- ((exps ...)
- (resolve:list exps env))
- (_ exp)))
-
;; TODO: Add some kind of context object that is threaded through the
;; inference process so that when a type error occurs we can show the
;; expression that caused it.
(define (infer-types exp stage)
(call-with-unify-rollback
(lambda ()
- (let ((annotated (annotate-exp* exp stage)))
+ (let ((annotated (pk 'annotated (annotate-exp* exp stage))))
(infer annotated
'()
(lambda (subs)