• R/O
  • HTTP
  • SSH
  • HTTPS

Tags
No Tags

Frequently used words (click to add to your profile)

java c++ android linux c# objective-c qt windows 誰得 cocoa python php ruby game gui bathyscaphe c 翻訳 omegat 計画中(planning stage) framework twitter test dom vb.net directx btron arduino previewer ゲームエンジン

A categorical programming language


File Info

Rev. de203222d444a26d5f4b5ea9b4b6884ba50cb824
Size 6,602 bytes
Time 2022-11-28 07:34:34
Author Corbin
Log Message

Finish switching over to PureScript.

This is very messy and I'm not really pleased with it. The overall
speedup that I was hoping for? Not there!

Content

from rpython.rlib.objectmodel import compute_hash, r_dict, specialize
from rpython.rlib.rarithmetic import intmask

from cammylib.arrows import Given, buildUnary, buildCompound, unaryFunctors

# Atom symbols that are already canonical: Atom.canonicalize returns an atom
# unchanged when its symbol is listed here. Taken from the keys of
# cammylib.arrows.unaryFunctors.
BASIS_ATOMS = tuple(unaryFunctors.keys())

# Functor constructors that are already canonical; any other constructor is
# loaded from the hive and expanded by Functor.canonicalize.
BASIS_FUNCTORS = ("comp", "pair", "case", "curry", "uncurry", "pr", "fold")


class SymbolRewriter(object):
    """
    Base class for symbol rewriters.

    An instance is handed to the `rewriteSymbols` methods of the S-expression
    classes, which walk the expression, call `rewrite` on every symbol (atom
    symbols and functor constructors), and rebuild an S-expression with the
    same structure. Subclasses must override `rewrite`.
    """

    def rewrite(self, symbol):
        "Return the replacement for `symbol`. Must be overridden."
        # NotImplementedError (rather than `assert False`) states the problem
        # plainly and is not removed when assertions are stripped (-O).
        raise NotImplementedError("SymbolRewriter.rewrite is abstract")


def equalSexp(a1, a2):
    """
    Structural equality of two S-expressions; the key-equality function
    used by sexpDict.

    Atoms are equal when their symbols are equal, holes when their indices
    are equal, and functors when their constructors match and their
    arguments are pairwise equal (compared recursively). Otherwise, values
    of different kinds compare unequal.
    """
    # Identity short-circuit: interned S-expressions are shared objects.
    if a1 is a2:
        return True
    if isinstance(a1, Atom) and isinstance(a2, Atom):
        return a1.symbol == a2.symbol
    if isinstance(a1, Hole) and isinstance(a2, Hole):
        return a1.index == a2.index
    if isinstance(a1, Functor) and isinstance(a2, Functor):
        if a1.constructor != a2.constructor:
            return False
        count = len(a1.arguments)
        if count != len(a2.arguments):
            return False
        for position in range(count):
            if not equalSexp(a1.arguments[position], a2.arguments[position]):
                return False
        return True
    return False

def hashSexp(sexp):
    "Key-hash function used by sexpDict: delegate to the node's own hash()."
    h = sexp.hash()
    return h

@specialize.call_location()
def sexpDict():
    """
    Create an empty r_dict keyed by S-expressions, using structural equality
    (equalSexp) and S-expression hashing (hashSexp).

    `specialize.call_location()` asks RPython to specialize this function per
    call site -- presumably so each dict created here can be annotated with
    its own value type; TODO confirm.
    """
    # NOTE(review): simple_hash_eq=True presumably declares equalSexp/hashSexp
    # simple enough for RPython to optimize around -- verify against r_dict docs.
    return r_dict(equalSexp, hashSexp, simple_hash_eq=True)

class SexpCache(object):
    """
    Interning table for S-expressions.

    Structurally-equal S-expressions (per equalSexp) map to one stored
    representative, so interned values can be compared with `is`. The
    Atom/Functor/Hole methods build a node and return its interned copy.
    """

    def __init__(self):
        # Each stored key maps to itself.
        self.d = sexpDict()

    def intern(self, sexp):
        "Return the stored S-expression equal to `sexp`, storing `sexp` if absent."
        table = self.d
        return table.setdefault(sexp, sexp)

    def Atom(self, symbol):
        "Return the interned atom for `symbol`."
        atom = Atom(symbol)
        return self.intern(atom)

    def Functor(self, constructor, arguments):
        """
        Return the interned functor `constructor` applied to `arguments`.

        Each argument is interned first, so the children of an interned
        functor are themselves interned.
        """
        args = [self.intern(arg) for arg in arguments]
        node = Functor(constructor, args)
        return self.intern(node)

    def Hole(self, index):
        "Return the interned hole numbered `index`."
        hole = Hole(index)
        return self.intern(hole)


class SExp(object):
    """
    Base class for S-expressions.

    The concrete kinds are Atom (a symbol), Functor (a constructor applied to
    argument S-expressions) and Hole (a numbered placeholder).
    """

class Atom(SExp):
    "An S-expression atom: a single symbol with no arguments."

    # RPython hint: `symbol` is never reassigned after __init__.
    _immutable_fields_ = ("symbol",)

    def __init__(self, symbol):
        self.symbol = symbol

    def hash(self):
        "Hash by symbol, matching equalSexp's symbol comparison for atoms."
        return compute_hash(self.symbol)

    def asStr(self):
        "Render as source text: an atom prints as its bare symbol."
        return self.symbol

    def symbols(self, s):
        "Record this atom's symbol as a key of the dict `s` (used as a set)."
        s[self.symbol] = None

    def substitute(self, args):
        "Atoms contain no holes, so substitution leaves them unchanged."
        return self

    def rewriteSymbols(self, rewriter):
        "Return the interned atom whose symbol is `rewriter.rewrite(symbol)`."
        newSymbol = rewriter.rewrite(self.symbol)
        return sexp.Atom(newSymbol)

    def canonicalize(self, hive):
        """
        Basis atoms are already canonical; any other symbol is resolved
        through `hive.load`.
        """
        if self.symbol not in BASIS_ATOMS:
            return hive.load(self.symbol)
        return self

    def occurs(self, index):
        "An atom never contains a hole."
        return False

    def extractType(self, extractor, formatter, outerPrecedence):
        """
        Format this atom as a type. The symbol "N" is rendered as
        `formatter.formatN`; any other symbol is returned verbatim.
        `extractor` and `outerPrecedence` are unused for atoms.
        """
        if self.symbol != "N":
            return self.symbol
        return formatter.formatN

    def buildArrow(self):
        "Build this atom's arrow via buildUnary."
        return buildUnary(self.symbol)

    def countHoles(self):
        "Atoms contain no holes."
        return 0

# Binding levels for the operator-style type notation, consulted by
# Functor.extractType when deciding on parentheses; constructors not listed
# here are treated as level 0.
PRECEDENCE = dict(hom=3, pair=2, sum=1)

class Functor(SExp):
    "A list of S-expressions with a distinguished head."

    # RPython JIT hints: these fields are not reassigned after __init__, and
    # the suffix on `arguments` covers the list's contents too.
    # NOTE(review): confirm `[:]` vs the documented `[*]` list syntax.
    _immutable_fields_ = "constructor", "arguments[:]"

    def __init__(self, constructor, arguments):
        self.constructor = constructor
        self.arguments = arguments

    def hash(self):
        """
        Hash consistent with equalSexp: combine the constructor's hash with
        every argument's hash, in order.

        The mix is order-sensitive (multiply-then-xor, as in tuple hashing),
        so that (pair a b) and (pair b a) hash differently and repeated
        arguments no longer cancel out, as they did with a plain XOR.
        intmask keeps the untranslated result a machine-sized int, matching
        the wrap-around behaviour of translated RPython ints.
        """
        rv = compute_hash(self.constructor)
        for arg in self.arguments:
            rv = intmask((rv * 1000003) ^ arg.hash())
        return rv

    def asStr(self):
        """
        Render as S-expression source text: `(constructor arg1 arg2 ...)`.
        With no arguments this is `(constructor)`, without a dangling space.
        """
        if not self.arguments:
            return "(%s)" % self.constructor
        args = " ".join([arg.asStr() for arg in self.arguments])
        return "(%s %s)" % (self.constructor, args)

    def symbols(self, s):
        "Record the constructor and every symbol in the arguments as keys of `s`."
        s[self.constructor] = None
        for arg in self.arguments:
            arg.symbols(s)

    def substitute(self, args):
        "Fill holes in every argument from `args`; returns an interned Functor."
        return sexp.Functor(self.constructor,
                [arg.substitute(args) for arg in self.arguments])

    def rewriteSymbols(self, rewriter):
        "Rewrite the constructor and, recursively, every argument's symbols."
        return sexp.Functor(rewriter.rewrite(self.constructor),
                [arg.rewriteSymbols(rewriter) for arg in self.arguments])

    def canonicalize(self, hive):
        """
        Canonicalize the arguments first. Basis functors are rebuilt around
        them; any other constructor is loaded from `hive` and its holes are
        filled with the canonicalized arguments.
        """
        args = [arg.canonicalize(hive) for arg in self.arguments]
        if self.constructor in BASIS_FUNCTORS:
            return sexp.Functor(self.constructor, args)
        else:
            functor = hive.load(self.constructor)
            return functor.substitute(args)

    def occurs(self, index):
        "True if hole number `index` appears anywhere in the arguments."
        for arg in self.arguments:
            if arg.occurs(index):
                return True
        return False

    def extractType(self, extractor, formatter, outerPrecedence):
        """
        Format this functor as a type.

        Only the constructors hom, pair, sum and list are handled, and every
        argument must be a Hole (enforced by unhole). Each argument is
        formatted at this constructor's precedence, and the result is
        parenthesized when `outerPrecedence` is at least that precedence,
        unless the formatter resets brackets for this constructor.
        """
        innerPrecedence = PRECEDENCE.get(self.constructor, 0)
        args = [extractor.extractWithPrecedence(unhole(arg), innerPrecedence)
                for arg in self.arguments]
        if self.constructor == "hom":
            rv = formatter.formatHom(args[0], args[1])
        elif self.constructor == "pair":
            rv = formatter.formatPair(args[0], args[1])
        elif self.constructor == "sum":
            rv = formatter.formatSum(args[0], args[1])
        elif self.constructor == "list":
            rv = formatter.formatList(args[0])
        else:
            assert False, "whoopsie-doodle"
        # NB: This is traditionally >, not >=
        # but that forgets required strictness and would be confusing to read
        if (outerPrecedence >= innerPrecedence and
            not formatter.resetsBracketsFor(self.constructor)):
            rv = formatter.parenthesize(rv)
        return rv

    def buildArrow(self):
        "Build the arrows for the arguments, then combine them via buildCompound."
        args = [arg.buildArrow() for arg in self.arguments]
        return buildCompound(self.constructor, args)

    def countHoles(self):
        "Return the maximum countHoles() over the arguments, or 0 if there are none."
        rv = 0
        for arg in self.arguments:
            rv = max(rv, arg.countHoles())
        return rv

class Hole(SExp):
    "A numbered placeholder, written `@index`, where an S-expression could be."

    # RPython hint: `index` is never reassigned after __init__.
    _immutable_fields_ = ("index",)

    def __init__(self, index):
        self.index = index

    def hash(self):
        "Holes hash to their index, matching equalSexp's index comparison."
        return self.index

    def asStr(self):
        "Render as `@` followed by the decimal index."
        return "@%d" % self.index

    def symbols(self, s):
        "A hole names no symbols, so `s` is left untouched."

    def substitute(self, args):
        "Replace this hole with the `index`-th entry of `args`."
        replacement = args[self.index]
        return replacement

    def rewriteSymbols(self, rewriter):
        "Holes carry no symbols; nothing to rewrite."
        return self

    def canonicalize(self, hive):
        "Holes are returned unchanged."
        return self

    def occurs(self, index):
        "True exactly when this is hole number `index`."
        return index == self.index

    def extractType(self, extractor, formatter, outerPrecedence):
        "Delegate to `extractor.findTypeAlias` with this hole's index."
        return extractor.findTypeAlias(self.index)

    def buildArrow(self):
        "Build a Given arrow for this hole's index."
        return Given(self.index)

    def countHoles(self):
        "A hole at index i implies at least i + 1 holes."
        return 1 + self.index

def unhole(sexp):
    """
    Return the index of `sexp`, which must be a Hole.

    The assert catches misuse; it presumably also serves RPython's
    isinstance-based type narrowing, so keep it an assert.
    """
    assert isinstance(sexp, Hole), "implementation error"
    index = sexp.index
    return index


# One global cache for all S-expressions. The sexp.Atom / sexp.Functor /
# sexp.Hole calls throughout this module go through it, so structurally
# equal S-expressions built that way are the same object.
sexp = SexpCache()

# Unit tests. Must pass, or won't compile.
# They run at import time and check that interning returns identical objects
# for structurally equal atoms and functors.
assert sexp.Atom("id") is sexp.Atom("id")
assert sexp.Functor("curry", [sexp.Atom("id")]) is sexp.Functor("curry", [sexp.Atom("id")])