summaryrefslogtreecommitdiff
path: root/chickadee/graphics/seagull/pass-infer.scm
diff options
context:
space:
mode:
Diffstat (limited to 'chickadee/graphics/seagull/pass-infer.scm')
-rw-r--r--chickadee/graphics/seagull/pass-infer.scm182
1 files changed, 182 insertions, 0 deletions
diff --git a/chickadee/graphics/seagull/pass-infer.scm b/chickadee/graphics/seagull/pass-infer.scm
new file mode 100644
index 0000000..0ac204e
--- /dev/null
+++ b/chickadee/graphics/seagull/pass-infer.scm
@@ -0,0 +1,182 @@
+;;; Chickadee Game Toolkit
+;;; Copyright © 2023 David Thompson <dthompson2@worcester.edu>
+;;;
+;;; Licensed under the Apache License, Version 2.0 (the "License");
+;;; you may not use this file except in compliance with the License.
+;;; You may obtain a copy of the License at
+;;;
+;;; http://www.apache.org/licenses/LICENSE-2.0
+;;;
+;;; Unless required by applicable law or agreed to in writing, software
+;;; distributed under the License is distributed on an "AS IS" BASIS,
+;;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;;; See the License for the specific language governing permissions and
+;;; limitations under the License.
+
+;; Walk the CPS control flow graph and solve for all of the type
+;; variables using a variant of the Hindley-Milner type inference
+;; algorithm extended to handle qualified types (types with
+;; predicates.) GLSL is a statically typed language, but thanks to
+;; type inference the user doesn't have to specify any types expect
+;; for shader inputs, outputs, and uniforms.
+
+;; Dedicated to the memory of Ian Denhardt (zenhack), who pointed me
+;; towards type predicate systems and answered my questions on the
+;; fediverse. That was the critical piece that I, someone who knows
+;; little about static typing, needed to extend traditional
+;; Hindley-Milner type inference to work with GLSL's function
+;; overloading.
+(define-module (chickadee graphics seagull pass-infer)
+ #:use-module (chickadee graphics seagull cps)
+ #:use-module (chickadee graphics seagull primitives)
+ #:use-module (chickadee graphics seagull syntax)
+ #:use-module (chickadee graphics seagull types)
+ #:use-module (chickadee graphics seagull utils)
+ #:use-module (ice-9 exceptions)
+ #:use-module (ice-9 match)
+ #:use-module (language cps intmap)
+ #:use-module (srfi srfi-1)
+ #:use-module (srfi srfi-11)
+ #:export (infer))
+
+(define bool (lookup-type 'bool))
+(define int (lookup-type 'int))
+
+(define &seagull-type-error
+ (make-exception-type '&seagull-type-error &error '()))
+
+(define make-seagull-type-error
+ (record-constructor &seagull-type-error))
+
+(define (seagull-type-error msg args src origin)
+ (raise-exception
+ (make-exception
+ (make-seagull-type-error)
+ (make-exception-with-origin origin)
+ (make-exception-with-message
+ (format #f "seagull type error at ~a: ~a"
+ (sourcev->string src)
+ msg))
+ (make-exception-with-irritants args))))
+
+(define (infer:values vals graph env subs)
+ (values graph
+ (map (lambda (var) (lookup var env)) vals)
+ subs))
+
+(define (infer:assignment var val graph env subs)
+ (let ((subs* (unify (lookup var env) (lookup val env))))
+ (values graph '() (compose-substitutions subs subs*))))
+
+(define (infer:primitive-call op args graph env subs)
+ (let* ((return-types (list (fresh-type-variable)))
+ (call-type (make-function-type
+ (map (lambda (arg) (lookup arg env)) args)
+ return-types))
+ (call-subs (unify call-type (primitive-operator-type op)))
+ (subs* (compose-substitutions subs call-subs)))
+ (values graph (substitute-types subs* return-types) subs*)))
+
+(define (infer* graph k types result-types env subs)
+ (define (type-mismatch-handler src e)
+ (let ((args (exception-args e)))
+ (match args
+ (((a ...) (b ...))
+ (seagull-type-error (format #f "expected ~a, got ~a"
+ (map type-name b) (map type-name a))
+ args src infer*))
+ ((a b)
+ (seagull-type-error (format #f "expected ~a, got ~a"
+ (type-name b) (type-name a))
+ args src infer*)))))
+ (define (with-error-handling src thunk)
+ (with-exception-handler (lambda (e) (type-mismatch-handler src e))
+ thunk
+ #:unwind? #t
+ #:unwind-for-type 'type-mismatch))
+ (define (infer-exp exp env)
+ (match (pk 'infer-exp exp)
+ (($ <cps-constant> _ type)
+ (values graph (list type) subs))
+ (($ <cps-values> vals)
+ (infer:values vals graph env subs))
+ (($ <cps-assignment> var val)
+ (infer:assignment var val graph env subs))
+ (($ <cps-primitive-call> (= lookup-primitive-operator op) args _)
+ (infer:primitive-call op args graph env subs))
+ (($ <cps-function> body)
+ (infer* graph body '() '() env subs))))
+ (define (infer-term term env)
+ (match (pk 'infer-term term)
+ ;; Regular continuation.
+ (($ <continue> src k* exp)
+ (with-error-handling src
+ (lambda ()
+ (let-values (((graph* types* subs*) (infer-exp exp env)))
+ (infer* graph* k* types* result-types env subs*)))))
+ ;; Function exit.
+ (($ <return> results)
+ (let* ((return-types (map (lambda (var) (lookup var env)) results))
+ ;; TODO: Unify with the return types of the current function.
+ ;;(expected-result-types result-types)
+ (subs* (unify return-types result-types)))
+ (values graph result-types (compose-substitutions subs subs*)))
+ (values graph types subs))
+ ;; Conditional branch.
+ (($ <branch> src name k-conseq k-alt)
+ (with-error-handling src
+ (lambda ()
+ ;; Type checking a branch goes like this:
+ ;; 1) Unify type of test variable with bool.
+ ;; 2) Infer types of consequent and alternate.
+ ;; 3) Unify types of consequent and alternate.
+ (let*-values (((subs1) (unify (lookup name env) bool))
+ ((subs2) (compose-substitutions subs subs1))
+ ((graph1 conseq-types subs3)
+ (infer* graph k-conseq types result-types env subs2))
+ ((graph2 alt-types subs4)
+ (infer* graph1 k-alt types result-types env subs3))
+ ((subs5) (unify conseq-types alt-types))
+ ((subs6) (compose-substitutions subs4 subs5)))
+ (values graph2 conseq-types subs6)))))))
+ (pk 'subs subs)
+ (pk 'env env)
+ (match (intmap-ref graph k)
+ (($ <arguments> names term _)
+ ;; Add newly defined variables to type environment, then infer
+ ;; the term in that environment.
+ (let*-values (((env*) (fold extend env names types))
+ ((graph* types* subs*) (infer-term term env*))
+ ((subs2) (unify types types*)))
+ (pk 'term-subs subs*)
+ (pk 'term-types 'before types 'after types*)
+ (values (intmap-replace graph* k (make-arguments names term types*))
+ types*
+ subs*
+ ;; (compose-substitutions subs* subs2)
+ )))
+ (($ <function-entry> src params results start return _)
+ (pk 'infer-function src params results start return)
+ ;; We don't know the type signature yet, so params and results
+ ;; are all type variables.
+ (let* ((param-types (fresh-type-variables params))
+ (result-types (fresh-type-variables results))
+ (func-type (make-function-type param-types result-types))
+ ;; Add params and results to type environment.
+ (env* (fold extend env
+ (append params results)
+ (append param-types result-types))))
+ ;; Infer types of function body using new type environment.
+ (define-values (graph* types* subs*)
+ (infer* graph start types result-types env* subs))
+ ;; Apply substitutions to function type.
+ ;; TODO: Handle polymorphism and type predicates.
+ (let* ((func-type* (substitute-type subs* func-type))
+ (func (make-function-entry src params results start return
+ func-type)))
+ (values (intmap-replace graph* k func)
+ (list func-type*)
+ subs*))))))
+
+(define (infer graph)
+ (infer* graph 0 '() '() (fresh-environment) no-substitutions))