From cbc4c95a0820934600a26329ee58cc5685060dcb Mon Sep 17 00:00:00 2001 From: David Thompson Date: Fri, 3 Feb 2023 22:06:13 -0500 Subject: Qualified types that mostly work. --- chickadee/graphics/seagull.scm | 1325 +++++++++++++++++++++------------------- 1 file changed, 693 insertions(+), 632 deletions(-) diff --git a/chickadee/graphics/seagull.scm b/chickadee/graphics/seagull.scm index 078e4aa..1ae82bb 100644 --- a/chickadee/graphics/seagull.scm +++ b/chickadee/graphics/seagull.scm @@ -18,18 +18,25 @@ ;;; Commentary: ;; ;; The Seagull shading language is a purely functional, statically -;; typed, Scheme-like language that can be compiled to GLSL code. +;; typed, Scheme-like language that can be compiled to GLSL code. The +;; reality of how GPUs work imposes some significant language +;; restrictions, but they are restrictions anyone who writes shader +;; code is already used to. ;; -;; Notable features and restrictions: +;; Features: ;; - Purely functional ;; - Vertex and fragment shader output ;; - Targets multiple GLSL versions ;; - Type inference ;; - Lexical scoping -;; - First-order functions -;; - Nested functions (but no closures!) +;; - Nested functions ;; - Multiple return values ;; +;; Limitations: +;; - No first-class functions +;; - No closures +;; - No recursion +;; ;;; Code: (define-module (chickadee graphics seagull) #:use-module (ice-9 exceptions) @@ -47,6 +54,8 @@ ;; TODO: ;; - Loops ;; - Shader stage linking +;; - Transform for-all functions into overloads +;; - User functions that use overloaded functions need to be overloaded themselves ;;; @@ -763,14 +772,14 @@ ;;; -;;; Typed expressions +;;; Type inference ;;; -;; Typed expressions combine a Seagull expression with the data types -;; that expression returns. This section includes all the kinds of -;; types an expression may have and various transformation procedures -;; to manipulate them. This will be used below to convert the mostly -;; untyped Seagull program into a fully typed one. 
+;; Walk the expression tree of a type annotated program and solve for +;; all of the type variables using a variant of the Hindley-Milner +;; type inference algorithm. GLSL is a statically typed language, but +;; thanks to type inference the user doesn't have to specify any types +;; except for shader inputs, outputs, and uniforms. ;; Primitive types: (define (primitive-type name) @@ -825,6 +834,9 @@ (define (fresh-type-variable) (type-variable (unique-type-variable-name))) +(define (fresh-type-variables-for-list lst) + (map (lambda (_x) (fresh-type-variable)) lst)) + (define (type-variable? obj) (match obj (('tvar _) #t) @@ -847,37 +859,43 @@ (match type (('-> _ returns) returns))) +;; Qualified types: +;; (define (qualified-type type predicate) +;; `(qualified ,type ,predicate)) + +;; (define (qualified-type? type) +;; (match type +;; (('qualified _ _) #t) +;; (_ #f))) + +;; (define (qualified-type-ref type) +;; (match type +;; (('qualified type* _) type*))) + +;; (define (qualified-type-predicate type) +;; (match type +;; (('qualified _ predicate) predicate))) + ;; For all types: -(define (for-all-type quantifiers type) - `(for-all ,quantifiers ,type)) +(define (for-all-type quantifiers type predicate) + `(for-all ,quantifiers ,type ,predicate)) (define (for-all-type? obj) (match obj - (('for-all _ _) #t) + (('for-all _ _ _) #t) (_ #f))) (define (for-all-type-quantifiers type) (match type - (('for-all q _) q))) + (('for-all q _ _) q))) (define (for-all-type-ref type) (match type - (('for-all _ t) t))) - -;; Overload types: -(define (overload-type . types) - (match types - (((? function-type?) ...) - `(overload ,types)))) - -(define (overload-type? obj) - (match obj - (('overload _) #t) - (_ #f))) + (('for-all _ t _) t))) -(define (overload-type-ref type) +(define (for-all-type-predicate type) (match type - (('overload types) types))) + (('for-all _ _ p) p))) ;; Outputs type: (define outputs-type '(outputs)) @@ -891,7 +909,8 @@ (or (primitive-type? 
obj) (type-variable? obj) (function-type? obj) - (overload-type? obj))) + ;; (qualified-type? obj) + (outputs-type? obj))) (define (apply-substitution-to-type type from to) (cond @@ -908,11 +927,13 @@ (map (lambda (return-type) (apply-substitution-to-type return-type from to)) (function-type-returns type)))) - ((overload-type? type) - (apply overload-type - (map (lambda (type) - (apply-substitution-to-type type from to)) - (overload-type-ref type)))) + ;; ((qualified-type? type) + ;; (qualified-type (apply-substitution-to-type + ;; (qualified-type-ref type) from to) + ;; (apply-substitution-to-predicate + ;; (qualified-type-predicate type) from to))) + ((for-all-type? type) + type) (else (error "invalid type" type)))) (define (apply-substitutions-to-type type subs) @@ -926,6 +947,53 @@ (apply-substitutions-to-type type subs)) types)) +(define (apply-substitution-to-env env from to) + (env-fold (lambda (name types env*) + (extend-env name + (map (lambda (type) + (apply-substitution-to-type type from to)) + types) + env*)) + (empty-env) + env)) + +(define (apply-substitutions-to-env env subs) + (env-fold (lambda (from to env*) + (apply-substitution-to-env env* from to)) + env + subs)) + +(define (apply-substitutions-to-texp t subs) + (texp (apply-substitutions-to-types (texp-types t) subs) + (texp-exp t))) + +(define (apply-substitution-to-predicate pred from to) + (match pred + (#t #t) + (('= a b) + `(= ,(apply-substitution-to-type a from to) + ,(apply-substitution-to-type b from to))) + (('and preds ...) + `(and ,@(map (lambda (pred) + (apply-substitution-to-predicate pred from to)) + preds))) + (('or preds ...) 
+ `(or ,@(map (lambda (pred) + (apply-substitution-to-predicate pred from to)) + preds))) + (('when test consequent) + `(when ,(apply-substitution-to-predicate test from to) + ,(apply-substitution-to-predicate consequent from to))) + (('substitute a b) + `(substitute ,(apply-substitution-to-type a from to) + ,(apply-substitution-to-type b from to))))) + +(define (apply-substitutions-to-predicate pred subs) + (env-fold (lambda (from to pred*) + (apply-substitution-to-predicate pred* from to)) + pred + subs)) + ;; Typed expressions: (define (texp types exp) `(t ,types ,exp)) @@ -948,184 +1016,217 @@ ((type) type) (_ (error "expected only 1 type" texp)))) - -;;; -;;; Type annotation -;;; +(define &seagull-type-error + (make-exception-type '&seagull-type-error &error '())) + +(define make-seagull-type-error + (record-constructor &seagull-type-error)) + +(define (seagull-type-error msg args origin) + (raise-exception + (make-exception + (make-seagull-type-error) + (make-exception-with-origin origin) + (make-exception-with-message + (format #f "seagull type error: ~a" msg)) + (make-exception-with-irritants args)))) + +(define (occurs? a b) + (cond + ((and (type-variable? a) (type-variable? b)) + (eq? a b)) + ((and (type-variable? a) (function-type? b)) + (or (occurs? a (function-type-parameters b)) + (occurs? a (function-type-returns b)))) + ((and (type? a) (list? b)) + (any (lambda (b*) (occurs? a b*)) b)) + (else #f))) + +(define (compose-substitutions a b) + (define b* + (map (match-lambda + ((from . to) + (cons from (apply-substitutions-to-type to a)))) + b)) + (define a* + (filter-map (match-lambda + ((from . to) + (if (assq-ref b* from) + #f + (cons from to)))) + a)) + (append a* b*)) -;; Convert untyped Seagull expressions into typed expressions with -;; type variables representing all unknown types. This annotated -;; version of a Seagull program can then be passed to the type -;; inference algorithm to solve for all of the variables. 
- -(define add/sub-type - (list (overload-type - (function-type (list int-type int-type) - (list int-type)) - (function-type (list float-type float-type) - (list float-type)) - (function-type (list vec2-type vec2-type) - (list vec2-type)) - (function-type (list vec3-type vec3-type) - (list vec3-type)) - (function-type (list vec4-type vec4-type) - (list vec4-type)) - (function-type (list mat3-type mat3-type) - (list mat3-type)) - (function-type (list mat4-type mat4-type) - (list mat4-type))))) - -(define mul-type - (list (overload-type - (function-type (list int-type int-type) - (list int-type)) - (function-type (list float-type float-type) - (list float-type)) - (function-type (list vec2-type vec2-type) - (list vec2-type)) - (function-type (list vec2-type float-type) - (list vec2-type)) - (function-type (list float-type vec2-type) - (list vec2-type)) - (function-type (list vec3-type vec3-type) - (list vec3-type)) - (function-type (list vec3-type float-type) - (list vec3-type)) - (function-type (list float-type vec3-type) - (list vec3-type)) - (function-type (list vec4-type vec4-type) - (list vec4-type)) - (function-type (list vec4-type float-type) - (list vec4-type)) - (function-type (list float-type vec4-type) - (list vec4-type)) - (function-type (list mat3-type mat3-type) - (list mat3-type)) - (function-type (list mat3-type vec3-type) - (list mat3-type)) - (function-type (list vec3-type mat3-type) - (list mat3-type)) - (function-type (list mat4-type mat4-type) - (list mat4-type)) - (function-type (list mat4-type vec4-type) - (list vec4-type)) - (function-type (list vec4-type mat4-type) - (list vec4-type))))) - -(define div-type - (list (overload-type - (function-type (list int-type int-type) - (list int-type)) - (function-type (list float-type float-type) - (list float-type)) - (function-type (list vec2-type vec2-type) - (list vec2-type)) - (function-type (list vec2-type float-type) - (list vec2-type)) - (function-type (list vec3-type vec3-type) - (list 
vec3-type)) - (function-type (list vec3-type float-type) - (list vec3-type)) - (function-type (list vec4-type vec4-type) - (list vec4-type)) - (function-type (list vec4-type float-type) - (list vec4-type)) - (function-type (list mat3-type float-type) - (list mat3-type)) - (function-type (list mat4-type float-type) - (list mat4-type))))) - -(define comparison-type - (list (overload-type - (function-type (list int-type int-type) - (list bool-type)) - (function-type (list float-type float-type) - (list bool-type))))) - -(define make-vec2-type - (list (function-type (list float-type float-type) - (list vec2-type)))) - -(define make-vec3-type - (list (overload-type - (function-type (list float-type float-type float-type) - (list vec3-type)) - (function-type (list vec2-type float-type) - (list vec3-type)) - (function-type (list float-type vec2-type) - (list vec3-type))))) - -(define make-vec4-type - (list (overload-type - (function-type (list float-type float-type float-type float-type) - (list vec4-type)) - (function-type (list vec2-type float-type float-type) - (list vec4-type)) - (function-type (list float-type vec2-type float-type) - (list vec4-type)) - (function-type (list float-type float-type vec2-type) - (list vec4-type)) - (function-type (list vec3-type float-type) - (list vec4-type)) - (function-type (list float-type vec3-type) - (list vec4-type))))) - -(define abs-type - (list (overload-type - (function-type (list int-type) (list int-type)) - (function-type (list float-type) (list float-type))))) - -(define sqrt-type - (list (overload-type - (function-type (list int-type) (list int-type)) - (function-type (list float-type) (list float-type))))) - -(define min/max-type - (list (overload-type - (function-type (list int-type int-type) (list int-type)) - (function-type (list float-type float-type) (list float-type))))) - -(define trigonometry-type - (list (function-type (list float-type) (list float-type)))) - -(define clamp/mix-type - (list (overload-type - 
(function-type (list int-type int-type int-type) - (list int-type)) - (function-type (list float-type float-type float-type) - (list float-type))))) +;; (define add/sub-type +;; (list (overload-type +;; (function-type (list int-type int-type) +;; (list int-type)) +;; (function-type (list float-type float-type) +;; (list float-type)) +;; (function-type (list vec2-type vec2-type) +;; (list vec2-type)) +;; (function-type (list vec3-type vec3-type) +;; (list vec3-type)) +;; (function-type (list vec4-type vec4-type) +;; (list vec4-type)) +;; (function-type (list mat3-type mat3-type) +;; (list mat3-type)) +;; (function-type (list mat4-type mat4-type) +;; (list mat4-type))))) + +;; (define mul-type +;; (list (overload-type +;; (function-type (list int-type int-type) +;; (list int-type)) +;; (function-type (list float-type float-type) +;; (list float-type)) +;; (function-type (list vec2-type vec2-type) +;; (list vec2-type)) +;; (function-type (list vec2-type float-type) +;; (list vec2-type)) +;; (function-type (list float-type vec2-type) +;; (list vec2-type)) +;; (function-type (list vec3-type vec3-type) +;; (list vec3-type)) +;; (function-type (list vec3-type float-type) +;; (list vec3-type)) +;; (function-type (list float-type vec3-type) +;; (list vec3-type)) +;; (function-type (list vec4-type vec4-type) +;; (list vec4-type)) +;; (function-type (list vec4-type float-type) +;; (list vec4-type)) +;; (function-type (list float-type vec4-type) +;; (list vec4-type)) +;; (function-type (list mat3-type mat3-type) +;; (list mat3-type)) +;; (function-type (list mat3-type vec3-type) +;; (list mat3-type)) +;; (function-type (list vec3-type mat3-type) +;; (list mat3-type)) +;; (function-type (list mat4-type mat4-type) +;; (list mat4-type)) +;; (function-type (list mat4-type vec4-type) +;; (list vec4-type)) +;; (function-type (list vec4-type mat4-type) +;; (list vec4-type))))) + +;; (define div-type +;; (list (overload-type +;; (function-type (list int-type int-type) +;; (list 
int-type)) +;; (function-type (list float-type float-type) +;; (list float-type)) +;; (function-type (list vec2-type vec2-type) +;; (list vec2-type)) +;; (function-type (list vec2-type float-type) +;; (list vec2-type)) +;; (function-type (list vec3-type vec3-type) +;; (list vec3-type)) +;; (function-type (list vec3-type float-type) +;; (list vec3-type)) +;; (function-type (list vec4-type vec4-type) +;; (list vec4-type)) +;; (function-type (list vec4-type float-type) +;; (list vec4-type)) +;; (function-type (list mat3-type float-type) +;; (list mat3-type)) +;; (function-type (list mat4-type float-type) +;; (list mat4-type))))) + +;; (define comparison-type +;; (list (overload-type +;; (function-type (list int-type int-type) +;; (list bool-type)) +;; (function-type (list float-type float-type) +;; (list bool-type))))) + +;; (define make-vec2-type +;; (list (function-type (list float-type float-type) +;; (list vec2-type)))) + +;; (define make-vec3-type +;; (list (overload-type +;; (function-type (list float-type float-type float-type) +;; (list vec3-type)) +;; (function-type (list vec2-type float-type) +;; (list vec3-type)) +;; (function-type (list float-type vec2-type) +;; (list vec3-type))))) + +;; (define make-vec4-type +;; (list (overload-type +;; (function-type (list float-type float-type float-type float-type) +;; (list vec4-type)) +;; (function-type (list vec2-type float-type float-type) +;; (list vec4-type)) +;; (function-type (list float-type vec2-type float-type) +;; (list vec4-type)) +;; (function-type (list float-type float-type vec2-type) +;; (list vec4-type)) +;; (function-type (list vec3-type float-type) +;; (list vec4-type)) +;; (function-type (list float-type vec3-type) +;; (list vec4-type))))) + +;; (define abs-type +;; (list (overload-type +;; (function-type (list int-type) (list int-type)) +;; (function-type (list float-type) (list float-type))))) + +;; (define sqrt-type +;; (list (overload-type +;; (function-type (list int-type) (list int-type)) 
+;; (function-type (list float-type) (list float-type))))) + +;; (define min/max-type +;; (list (overload-type +;; (function-type (list int-type int-type) (list int-type)) +;; (function-type (list float-type float-type) (list float-type))))) + +;; (define trigonometry-type +;; (list (function-type (list float-type) (list float-type)))) + +;; (define clamp/mix-type +;; (list (overload-type +;; (function-type (list int-type int-type int-type) +;; (list int-type)) +;; (function-type (list float-type float-type float-type) +;; (list float-type))))) (define (top-level-type-env stage) - `((+ . ,add/sub-type) - (- . ,add/sub-type) - (* . ,mul-type) - (/ . ,div-type) - (= . ,comparison-type) - (< . ,comparison-type) - (<= . ,comparison-type) - (> . ,comparison-type) - (>= . ,comparison-type) - (vec2 . ,make-vec2-type) - (vec3 . ,make-vec3-type) - (vec4 . ,make-vec4-type) - (not ,(function-type (list bool-type) (list bool-type))) - (int->float ,(function-type (list int-type) (list float-type))) - (float->int ,(function-type (list float-type) (list int-type))) - (abs . ,abs-type) - (sqrt . ,sqrt-type) - (min . ,min/max-type) - (max . ,min/max-type) - (sin . ,trigonometry-type) - (cos . ,trigonometry-type) - (tan . ,trigonometry-type) - (clamp . ,clamp/mix-type) - (mix . ,clamp/mix-type) - ,@(case stage - ((vertex) - `((gl-position ,vec4-type))) - ((fragment) - `((texture-2d ,(function-type (list sampler-2d-type vec2-type) - (list vec4-type)))))))) + '() + ;; `((+ . ,add/sub-type) + ;; (- . ,add/sub-type) + ;; (* . ,mul-type) + ;; (/ . ,div-type) + ;; (= . ,comparison-type) + ;; (< . ,comparison-type) + ;; (<= . ,comparison-type) + ;; (> . ,comparison-type) + ;; (>= . ,comparison-type) + ;; (vec2 . ,make-vec2-type) + ;; (vec3 . ,make-vec3-type) + ;; (vec4 . 
,make-vec4-type) + ;; (not ,(function-type (list bool-type) (list bool-type))) + ;; (int->float ,(function-type (list int-type) (list float-type))) + ;; (float->int ,(function-type (list float-type) (list int-type))) + ;; (abs . ,abs-type) + ;; (sqrt . ,sqrt-type) + ;; (min . ,min/max-type) + ;; (max . ,min/max-type) + ;; (sin . ,trigonometry-type) + ;; (cos . ,trigonometry-type) + ;; (tan . ,trigonometry-type) + ;; (clamp . ,clamp/mix-type) + ;; (mix . ,clamp/mix-type) + ;; ,@(case stage + ;; ((vertex) + ;; `((gl-position ,vec4-type))) + ;; ((fragment) + ;; `((texture-2d ,(function-type (list sampler-2d-type vec2-type) + ;; (list vec4-type))))))) + ) (define (lookup-type name env) (let ((type (lookup name env))) @@ -1175,13 +1276,40 @@ '() env))) -(define (generalize type env) +(define (free-variables-in-predicate pred) + (match pred + (#t '()) + (('= a b) + (append (free-variables-in-type a) + (free-variables-in-type b))) + (('and preds ...) + (append-map (lambda (pred) + (free-variables-in-predicate pred)) + preds)) + (('or preds ...) + (append-map (lambda (pred) + (free-variables-in-predicate pred)) + preds)) + (('when test consequent) + (append (free-variables-in-predicate test) + (free-variables-in-predicate consequent))) + (('substitute a b) + (append (free-variables-in-type a) + (free-variables-in-type b))))) + +;; Quantified variables: +;; - Unused parameters +;; - Parameters that appear free in the return type +;; - Parameters that are used in overloaded primcalls +(define (generalize type pred env) (if (function-type? type) - (match (difference (free-variables-in-type type) + (match (difference (delete-duplicates + (append (free-variables-in-type type) + (free-variables-in-predicate pred))) (free-variables-in-env env)) (() type) ((quantifiers ...) 
- (for-all-type quantifiers type))) + (for-all-type quantifiers type pred))) type)) (define (instantiate for-all) @@ -1190,451 +1318,384 @@ (extend-env var (fresh-type-variable) env)) (empty-env) (for-all-type-quantifiers for-all))) - (apply-substitutions-to-type (for-all-type-ref for-all) subs)) - -(define (fresh-type-variables-for-list lst) - (map (lambda (_x) (fresh-type-variable)) lst)) - -(define (annotate:list exps env) - (map (lambda (exp) (annotate-exp exp env)) exps)) - -(define (annotate:if predicate consequent alternate env) - (define consequent* (annotate-exp consequent env)) - (texp (texp-types consequent*) - `(if ,(annotate-exp predicate env) - ,consequent* - ,(annotate-exp alternate env)))) - -(define (annotate:let names exps body env) - (define exps* (annotate:list exps env)) - (define exp-types (map texp-types exps*)) - (define env* (fold extend-env env names exp-types)) - (define body* (annotate-exp body env*)) - (texp (texp-types body*) - `(let ,(map list names exps*) ,body*))) - -(define (annotate:lambda params body env) - ;; Each function parameter gets a fresh type variable. - (define param-types (fresh-type-variables-for-list params)) - ;; The type environment is extended with the function parameters. - (define env* - (fold (lambda (param type env*) - (extend-env param (list type) env*)) - env params param-types)) - ;; The body is annotated in the new environment. - (define body* (annotate-exp body env*)) - (texp (list (function-type param-types (texp-types body*))) - `(lambda ,params ,body*))) - -(define (annotate:values exps env) - (define exps* (annotate:list exps env)) - (texp (map single-type exps*) - `(values ,@exps*))) - -(define (annotate:primitive-call operator args env) - ;; The type signature of primitive functions can be looked up - ;; directly in the environment. 
- (define operator-types (lookup-type operator env)) - (define operator* (texp operator-types operator)) - (define args* (annotate:list args env)) - (texp (fresh-type-variables-for-list (list operator*)) - `(primcall ,operator* ,@args*))) - -(define (annotate:call operator args env) - (define operator* (annotate-exp operator env)) - (define args* (annotate:list args env)) - (texp (fresh-type-variables-for-list - (function-type-returns - (single-type operator*))) - `(call ,operator* ,@args*))) - -(define* (annotate:top-level bindings body env #:optional (result '())) - (match bindings - (() - (let ((body* (annotate-exp body env))) - (texp (texp-types body*) `(top-level ,(reverse result) ,body*)))) - ((('function name exp) . rest) - (define exp* - (let ((x (annotate-exp exp env))) - ;; Function types must be generalized so that functions like - ;; (lambda (x) x) can truly be applied to any type. - (texp (list (generalize (single-type x) env)) - (texp-exp x)))) - (define env* (extend-env name (texp-types exp*) env)) - (define result* (cons `(function ,name ,exp*) result)) - (annotate:top-level rest body env* result*)) - ((((? top-level-qualifier? qualifier) type-name name) . rest) - (define env* (extend-env name (list (type-name->type type-name)) env)) - (define result* (cons (list qualifier type-name name) result)) - (annotate:top-level rest body env* result*)))) - -(define (annotate:outputs names exps env) - (texp (list outputs-type) - `(outputs - ,@(map (lambda (name exp) - (list (texp (lookup name env) name) - (annotate-exp exp env))) - names exps)))) - -(define (annotate-exp exp env) - (match exp - ((? exact-integer?) - (texp (list int-type) exp)) - ((? float?) - (texp (list float-type) exp)) - ((? boolean?) - (texp (list bool-type) exp)) - (('var var _) - (texp (lookup-type var env) exp)) - (('if predicate consequent alternate) - (annotate:if predicate consequent alternate env)) - (('let ((names exps) ...) 
body) - (annotate:let names exps body env)) - (('lambda (params ...) body) - (annotate:lambda params body env)) - (('values exps ...) - (annotate:values exps env)) - (('primcall operator args ...) - (annotate:primitive-call operator args env)) - (('call operator args ...) - (annotate:call operator args env)) - (('outputs (names exps) ...) - (annotate:outputs names exps env)) - (('top-level bindings body) - (annotate:top-level bindings body env)))) - -(define (annotate-exp* exp stage) - (annotate-exp exp (top-level-type-env stage))) - - -;;; -;;; Type inference -;;; - -;; Walk the expression tree of a type annotated program and solve for -;; all of the type variables using a variant of the Hindley-Milner -;; type inference algorithm. GLSL is a statically typed language, but -;; thanks to type inference the user doesn't have to specify any types -;; expect for shader inputs, outputs, and uniforms. - -(define &seagull-type-error - (make-exception-type '&seagull-type-error &error '())) - -(define make-seagull-type-error - (record-constructor &seagull-type-error)) - -(define (seagull-type-error msg args origin) - (raise-exception - (make-exception - (make-seagull-type-error) - (make-exception-with-origin origin) - (make-exception-with-message - (format #f "seagull type error: ~a" msg)) - (make-exception-with-irritants args)))) - -(define (occurs? a b) + (values + (apply-substitutions-to-type (for-all-type-ref for-all) subs) + (apply-substitutions-to-predicate (for-all-type-predicate for-all) subs))) + +(define (maybe-instantiate types) + (define types+preds + (map (lambda (type) + (if (for-all-type? type) + (call-with-values (lambda () (instantiate type)) list) + (list type #t))) + types)) + (values (map first types+preds) + (reduce compose-predicates #t (map second types+preds)))) + +;; (define (strip-qualifier type) +;; (if (qualified-type? 
type) +;; (qualified-type-ref type) +;; type)) + +;; (define (strip-qualifiers types) +;; (map strip-qualifier types)) + +;; (define (predicate-for-type type) +;; (if (qualified-type? type) +;; (qualified-type-predicate type) +;; #t)) + +(define (compose-predicates a b) (cond - ((and (type-variable? a) (type-variable? b)) - (eq? a b)) - ((and (type-variable? a) (function-type? b)) - (or (occurs? a (function-type-parameters b)) - (occurs? a (function-type-returns b)))) - ((and (type? a) (list? b)) - (any (lambda (b*) (occurs? a b*)) b)) - (else #f))) - -(define (compose-substitutions a b) - (define b* - (map (match-lambda - ((from . to) - (cons from (apply-substitutions-to-type to a)))) - b)) - (define a* - (filter-map (match-lambda - ((from . to) - (if (assq-ref b* from) - #f - (cons from to)))) - a)) - (append a* b*)) - -(define unify-prompt-tag (make-prompt-tag 'unify)) - -(define (call-with-unify-rollback thunk handler) - (call-with-prompt unify-prompt-tag - thunk - (lambda (k args) - (apply handler args)))) - -(define (unify:fail . args) - (abort-to-prompt unify-prompt-tag args)) - -(define (unify:primitives a b success) + ((and (eq? a #t) (eq? b #t)) + #t) + ((eq? a #t) + b) + ((eq? b #t) + a) + (else + `(and ,a ,b)))) + +(define (compose-predicates* preds) + (reduce (lambda (pred memo) + (compose-predicates pred memo)) + #t + preds)) + +(define (compose-predicates-for-types types) + (compose-predicates* (map predicate-for-type types))) + +;; Produces a simplified predicate and a new set of substitutions for +;; predicates that have been satisfied and simplified to #t. +(define (eval-predicate pred) + (match pred + (#t (values #t '())) + ((or ('= (? type-variable?) _) + ('= _ (? type-variable?))) + (values pred '())) + (('= a b) + (values (equal? a b) '())) + (('or preds ...) + (let loop ((preds preds)) + (match preds + (() (values #f '())) + ((pred* . 
rest) + (define-values (new-pred subs) + (eval-predicate pred*)) + (match new-pred + (#t (values #t subs)) + (#f (eval-predicate `(or ,@rest))) + (_ (values `(or ,new-pred ,@rest) '()))))))) + (('and preds ...) + (let loop ((preds preds)) + (match preds + (() (values #t '())) + ((pred* . rest) + (define-values (new-pred subs) + (eval-predicate pred*)) + (match new-pred + (#t + (let () + (define-values (new-pred* subs*) + (eval-predicate `(and ,@rest))) + (values new-pred* (compose-substitutions subs subs*)))) + (#f (values #f '())) + (_ (values `(and ,new-pred ,@rest) '()))))))) + (('when test consequent) + (define-values (new-test subs) + (eval-predicate test)) + (match new-test + (#t + (let () + (define-values (new-consequent subs*) + (eval-predicate consequent)) + (values new-consequent (compose-substitutions subs subs*)))) + (#f (values #f '())) + (_ (values `(when ,new-test ,consequent) '())))) + (('substitute a b) + (values #t (list (cons a b)))))) + +(define (eval-predicate* pred subs) + (define-values (new-pred pred-subs) + (eval-predicate + (apply-substitutions-to-predicate pred subs))) + ;; TODO: Get information about *why* the predicate failed. + (unless new-pred + (error "predicate failure")) + (values new-pred (compose-substitutions subs pred-subs))) + +(define (unify:primitives a b) (if (equal? a b) - (success '()) - (unify:fail "primitive type mismatch" a b))) + '() + (error "primitive type mismatch" a b))) -(define (unify:variable a b success) +(define (unify:variable a b) (cond ((eq? a b) - (success '())) + '()) ((occurs? 
a b) - (unify:fail "type contains reference to itself" a b)) + (error "type contains reference to itself" a b)) (else - (success (list (cons a b)))))) - -(define (unify:functions a b success) - (unify (function-type-parameters a) - (function-type-parameters b) - (lambda (sub0) - (unify (apply-substitutions-to-types (function-type-returns a) - sub0) - (apply-substitutions-to-types (function-type-returns b) - sub0) - (lambda (sub1) - (success (compose-substitutions sub0 sub1))))))) - -(define (unify:overload a b success) - (define (try-unify functions) - (match functions - (() - (unify:fail "no matching overload" a b)) - ((function . rest) - (call-with-unify-rollback - (lambda () - (unify function b success)) - (lambda _ - (try-unify rest)))))) - (try-unify (overload-type-ref a))) - -(define (unify:lists a rest-a b rest-b success) - (unify a b - (lambda (sub0) - (unify (apply-substitutions-to-types rest-a sub0) - (apply-substitutions-to-types rest-b sub0) - (lambda (sub1) - (success (compose-substitutions sub0 sub1))))))) - -(define (unify a b success) + (list (cons a b))))) + +(define (unify:functions a b) + (define param-subs + (unify (function-type-parameters a) + (function-type-parameters b))) + (define return-subs + (unify (apply-substitutions-to-types (function-type-returns a) + param-subs) + (apply-substitutions-to-types (function-type-returns b) + param-subs))) + (compose-substitutions param-subs return-subs)) + +(define (unify:lists a rest-a b rest-b) + (define sub-first (unify a b)) + (define sub-rest + (unify (apply-substitutions-to-types rest-a sub-first) + (apply-substitutions-to-types rest-b sub-first))) + (compose-substitutions sub-first sub-rest)) + +(define (unify a b) (match (list a b) (((? primitive-type? a) (? primitive-type? b)) - (unify:primitives a b success)) + (unify:primitives a b)) ((or ((? type-variable? a) b) (b (? type-variable? a))) - (unify:variable a b success)) + (unify:variable a b)) (((? function-type? a) (? function-type? 
b)) - (unify:functions a b success)) - ((a (? overload-type? b)) - (unify:overload b a success)) + (unify:functions a b)) (((? outputs-type?) (? outputs-type?)) - (success '())) + '()) ((() ()) - (success '())) + '()) (((a rest-a ...) (b rest-b ...)) - (unify:lists a rest-a b rest-b success)) + (unify:lists a rest-a b rest-b)) (_ - (unify:fail "type mismatch" a b)))) - -(define (infer:list exps subs success) - (match exps - (() (success subs)) - ((exp . rest) - (infer exp - subs - (lambda (sub0) - (infer:list rest - sub0 - (lambda (sub1) - (success - (compose-substitutions sub0 sub1))))))))) - -(define (infer:if types predicate consequent alternate subs success) - (infer predicate - subs - (lambda (sub0) - (unify (single-type predicate) - bool-type - (lambda (sub1) - (define sub2 (compose-substitutions sub0 sub1)) - (infer consequent - sub2 - (lambda (sub3) - (infer alternate - sub3 - (lambda (sub4) - (unify (apply-substitutions-to-types - (texp-types consequent) - sub4) - (apply-substitutions-to-types - (texp-types alternate) - sub4) - (lambda (sub5) - (success - (compose-substitutions - sub4 sub5))))))))))))) - -(define (infer:lambda type body subs success) - (define type* - (if (for-all-type? type) - (for-all-type-ref type) - type)) - (infer body - subs - (lambda (sub0) - (unify (apply-substitutions-to-types (texp-types body) sub0) - (function-type-returns type*) - (lambda (sub1) - (success - (compose-substitutions sub0 sub1))))))) - -(define (infer:values types exps subs success) - (infer:list exps subs success)) - -(define (infer:call types operator args subs success) - (infer operator - subs - (lambda (sub0) - (infer:list - args - sub0 - (lambda (sub1) - ;; Check if function call has the proper number of - ;; arguments. 
- (let* ((k (length args)) - (l (length (function-type-parameters - (single-type operator))))) - (unless (= k l) - (seagull-type-error - (format #f "expected ~a arguments, got ~a" l k) - '() - infer:call))) - (unify (apply-substitutions-to-type - (function-type (map single-type args) - types) - sub1) - (apply-substitutions-to-type - (single-type operator) - sub1) - (lambda (sub2) - (success (compose-substitutions sub1 sub2))))))))) - -(define (infer:primcall types operator args subs success) - (infer:list - args - subs - (lambda (sub0) - (unify (apply-substitutions-to-type - (function-type (map single-type args) - types) - sub0) - (apply-substitutions-to-type - (single-type operator) - sub0) - (lambda (sub1) - (success - (compose-substitutions sub0 sub1))))))) - -(define (infer:let types names exps body subs success) - (infer:list - exps - subs - (lambda (sub0) - (infer body - sub0 - (lambda (sub1) - (unify (apply-substitutions-to-types types sub1) - (apply-substitutions-to-types (texp-types body) sub1) - (lambda (sub2) - (success - (compose-substitutions sub1 sub2))))))))) - -(define (infer:top-level types bindings body subs success) - (define exps - (filter-map (match-lambda - (('function _ exp) exp) - (_ #f)) - bindings)) - (infer:list - exps - subs - (lambda (sub0) - (infer body - sub0 - (lambda (sub1) - (unify (apply-substitutions-to-types types sub1) - (apply-substitutions-to-types (texp-types body) sub1) - (lambda (sub2) - (success - (compose-substitutions sub1 sub2))))))))) - -(define (infer:outputs names exps subs success) - (define output-types (map single-type names)) - (infer:list - exps - subs - (lambda (sub0) - (unify output-types - (apply-substitutions-to-types (map single-type exps) sub0) - (lambda (sub1) - (success - (compose-substitutions sub0 sub1))))))) - -(define (infer exp subs success) + (error "type mismatch" a b)))) + +(define (infer:immediate x) + (values (texp (list (cond + ((exact-integer? x) + int-type) + ((float? 
x) + float-type) + ((boolean? x) + bool-type))) + x) + '() + #t)) + +(define (infer:variable name env) + (define-values (types pred) + (maybe-instantiate (lookup-type name env))) + (values (texp types name) + '() + pred)) + +(define (infer:list exps env) + (let loop ((exps exps) + (texps '()) + (subs '()) + (pred #t)) + (match exps + (() + (values (reverse texps) subs pred)) + ((exp . rest) + (define-values (texp subs* pred*) + (infer-exp exp env)) + (define-values (new-pred combined-subs) + (eval-predicate* (compose-predicates pred pred*) + (compose-substitutions subs subs*))) + (loop rest + (cons texp texps) + combined-subs + new-pred))))) + +(define (infer:if predicate consequent alternate env) + ;; Infer predicate types and unify it with the boolean type. + (define-values (predicate-texp predicate-subs predicate-pred) + (infer-exp predicate env)) + (define predicate-unify-subs + (unify (texp-types predicate-texp) (list bool-type))) + ;; Combine the substitutions and apply them to the environment. + (define combined-subs-0 + (compose-substitutions predicate-subs predicate-unify-subs)) + (define env0 + (apply-substitutions-to-env env combined-subs-0)) + ;; Infer consequent and alternate types and unify them against each + ;; other. Each branch of an 'if' should have the same type. + (define-values (consequent-texp consequent-subs consequent-pred) + (infer-exp consequent env0)) + (define combined-subs-1 + (compose-substitutions combined-subs-0 consequent-subs)) + (define env1 + (apply-substitutions-to-env env0 consequent-subs)) + (define-values (alternate-texp alternate-subs alternate-pred) + (infer-exp alternate env1)) + (define combined-subs-2 + (compose-substitutions combined-subs-1 alternate-subs)) + ;; Eval combined predicate. 
+ (define-values (pred combined-subs-3) + (eval-predicate* (compose-predicates predicate-pred + (compose-predicates consequent-pred + alternate-pred)) + combined-subs-2)) + ;; ;; Apply final set of substitutions to the types of both branches. + (define consequent-texp* + (apply-substitutions-to-texp consequent-texp combined-subs-3)) + (define alternate-texp* + (apply-substitutions-to-texp alternate-texp combined-subs-3)) + (values (texp (texp-types consequent-texp) + `(if ,predicate-texp ,consequent-texp ,alternate-texp)) + combined-subs-3 + pred)) + +(define (infer:lambda params body env) + ;; Each function parameter gets a fresh type variable. + (define param-types (fresh-type-variables-for-list params)) + ;; The type environment is extended with the function parameters. + (define env* + (fold (lambda (param type env*) + (extend-env param (list type) env*)) + env params param-types)) + (define-values (body* body-subs body-pred) + (infer-exp body env*)) + (define-values (pred subs) + (eval-predicate* body-pred body-subs)) + (values (texp (list (generalize + (function-type (apply-substitutions-to-types param-types + body-subs) + (texp-types body*)) + pred env)) + `(lambda ,params ,body*)) + subs #t)) + +(define (infer:primitive-call operator args env) + ;; The type signature of primitive functions can be looked up + ;; directly in the environment. Primitive functions may be + ;; overloaded and need to be instantiated with fresh type variables. + (define-values (types operator-pred) + (maybe-instantiate (lookup-type operator env))) + (define operator-type + (match types + ((type) type))) + ;; Infer the arguments. + (define-values (args* arg-subs arg-pred) + (infer:list args env)) + ;; Generate fresh type variables to unify against the return types + ;; of the operator. 
+ (define return-vars + (fresh-type-variables-for-list (function-type-returns operator-type))) + (define call-subs + (unify operator-type + (function-type (map single-type args*) + return-vars))) + ;; Apply substitutions to the predicate and then eval it, producing + ;; a simplified predicate and a set of substitutions. + (define-values (pred combined-subs) + (eval-predicate* (compose-predicates operator-pred arg-pred) + (compose-substitutions arg-subs call-subs))) + (values (texp (apply-substitutions-to-types return-vars combined-subs) + `(primcall ,operator + ,@(map (lambda (arg) + (apply-substitutions-to-texp arg + combined-subs)) + args*))) + combined-subs + pred)) + +(define (infer:call operator args env) + ;; The type signature of primitive functions can be looked up + ;; directly in the environment. + (define-values (operator* operator-subs operator-pred) + (infer-exp operator env)) + (define env* + (apply-substitutions-to-env env operator-subs)) + ;; Infer the arguments. + (define-values (args* arg-subs arg-pred) + (infer:list args env*)) + (define combined-subs-0 + (compose-substitutions operator-subs arg-subs)) + ;; Generate fresh type variables to unify against the return types + ;; of the operator. + (define operator-type (single-type operator*)) + (define return-vars + (fresh-type-variables-for-list + (function-type-returns operator-type))) + (define call-subs + (unify (apply-substitutions-to-type operator-type combined-subs-0) + (function-type (apply-substitutions-to-types (map single-type args*) + combined-subs-0) + return-vars))) + ;; Eval predicate. 
+ (define-values (pred combined-subs) + (eval-predicate* (compose-predicates operator-pred + arg-pred) + (compose-substitutions combined-subs-0 call-subs))) + (values (texp (apply-substitutions-to-types return-vars combined-subs) + `(call ,(apply-substitutions-to-texp operator* combined-subs) + ,@(map (lambda (arg) + (apply-substitutions-to-texp arg + combined-subs)) + args*))) + combined-subs + pred)) + +(define (infer:let names exps body env) + (define-values (exps* exp-subs exp-pred) + (infer:list exps env)) + (define exp-types (map texp-types exps*)) + (define env* + (fold extend-env + (apply-substitutions-to-env env exp-subs) + names + exp-types)) + (define-values (body* body-subs body-pred) + (infer-exp body env*)) + (define-values (pred combined-subs) + (eval-predicate* (compose-predicates exp-pred body-pred) + (compose-substitutions exp-subs body-subs))) + (values (texp (texp-types body*) + `(let ,(map list names exps*) ,body*)) + combined-subs + pred)) + +;; Inference returns 3 values: +;; - a typed expression +;; - a list of substitutions +;; - a type predicate +(define (infer-exp exp env) (match exp - (('t types (or (? immediate?) ('var _ _))) - (success subs)) - (('t types ('if predicate consequent alternate)) - (infer:if types predicate consequent alternate subs success)) - (('t types ('let ((names exps) ...) body)) - (infer:let types names exps body subs success)) - (('t (type) ('lambda (params ...) body)) - (infer:lambda type body subs success)) - (('t types ('values exps ...)) - (infer:values types exps subs success)) - (('t types ('primcall operator args ...)) - (infer:primcall types operator args subs success)) - (('t types ('call operator args ...)) - (infer:call types operator args subs success)) - (('t _ ('outputs (names exps) ...)) - (infer:outputs names exps subs success)) - (('t types ('top-level bindings body)) - (infer:top-level types bindings body subs success)) + ((? immediate?) + (infer:immediate exp)) + ((? symbol?) 
+ (infer:variable exp env)) + (('if predicate consequent alternate) + (infer:if predicate consequent alternate env)) + (('let ((names exps) ...) body) + (infer:let names exps body env)) + (('lambda (params ...) body) + (infer:lambda params body env)) + ;; (('values exps ...) + ;; (infer:values exps env)) + (('primcall operator args ...) + (infer:primitive-call operator args env)) + (('call operator args ...) + (infer:call operator args env)) + ;; (('outputs (names exps) ...) + ;; (infer:outputs names exps env)) + ;; (('top-level bindings body) + ;; (infer:top-level bindings body env)) (_ (error "unknown form" exp)))) -(define (resolve:type-variable var env) - (let ((type (assq-ref env var))) - (if type - (resolve type env) - var))) - -;; TODO: If there are still quantified type variables, scan the entire -;; program for calls to the function and build an intersection type. -(define (resolve:for-all type env) - (let ((quantifiers (filter type-variable? - (resolve (for-all-type-quantifiers type) env))) - (ref (resolve (for-all-type-ref type) env))) - (if (null? quantifiers) ref (for-all-type quantifiers ref)))) - -(define (resolve:list exps env) - (map (lambda (exp) (resolve exp env)) exps)) - -(define (resolve exp env) - (match exp - ((? type-variable?) - (resolve:type-variable exp env)) - ((? for-all-type?) - (resolve:for-all exp env)) - ((exps ...) - (resolve:list exps env)) - (_ exp))) - ;; TODO: Add some kind of context object that is threaded through the ;; inference process so that when a type error occurs we can show the ;; expression that caused it. (define (infer-types exp stage) (call-with-unify-rollback (lambda () - (let ((annotated (annotate-exp* exp stage))) + (let ((annotated (pk 'annotated (annotate-exp* exp stage)))) (infer annotated '() (lambda (subs) -- cgit v1.2.3