• R/O
  • HTTP
  • SSH
  • HTTPS

Commit

Tags
No Tags

Frequently used words (click to add to your profile)

java, c++, android, linux, c#, windows, objective-c, cocoa, 誰得, qt, python, php, ruby, game, gui, bathyscaphe, c, 計画中(planning stage), 翻訳, omegat, framework, twitter, dom, test, vb.net, directx, ゲームエンジン, btron, arduino, previewer

Emergent generative agents


Commit MetaInfo

Revision 710ae6e7774980d003893e18f50bdabc95a00f6f (tree)
Time 2023-05-12 04:28:12
Author Corbin <cds@corb...>
Committer Corbin

Log Message

Parameterize model path.

So that it can run on other people's machines!

Change Summary

Incremental Difference

--- a/src/agent.py
+++ b/src/agent.py
@@ -17,9 +17,7 @@ from twisted.internet.threads import deferToThread
1717 from twisted.words.protocols.irc import IRCClient
1818
1919 from common import irc_line, Timer, SentenceIndex, breakAt
20-# from gens.camelid import CamelidGen
2120 from gens.mawrkov import MawrkovGen
22-# from gens.trans import Flavor, HFGen, SentenceEmbed
2321 from gens.trans import SentenceEmbed
2422
2523 build_traits = " + ".join
@@ -32,7 +30,8 @@ def load_character(path):
3230
3331 MAX_NEW_TOKENS = 128
3432 print("~ Initializing mawrkov adapter…")
35-gen = MawrkovGen(MAX_NEW_TOKENS)
33+model_path = sys.argv[1]
34+gen = MawrkovGen(model_path, MAX_NEW_TOKENS)
3635 # Need to protect per-gen data structures in C.
3736 genLock = Lock()
3837 GiB = 1024 ** 3
@@ -72,16 +71,17 @@ class Mind:
7271 # Newlines are added here.
7372 self.logits, self.state = gen.feedForward(gen.tokenize(s + "\n"),
7473 self.logits, self.state)
75- print("~ Write:", s)
74+ print(s)
7675
7776 def complete(self, s):
7877 with genLock:
7978 completion, self.logits, self.state = gen.complete(s, self.logits, self.state)
80- print("~ Predicted:", completion)
79+ print("«", completion, "»")
8180 return completion
8281
8382 def infer(self, tag, prefix):
8483 def cb():
84+ print(prefix)
8585 d = self.switchTag(tag)
8686 d.addCallback(lambda _: deferToThread(self.complete, prefix))
8787 return d
@@ -126,7 +126,6 @@ class ChainOfThoughts(Agent):
126126 thoughts = self.index.search(s, 2)
127127 for thought in thoughts:
128128 if thought not in self.recentThoughts:
129- print("~ New relevant thought:", thought)
130129 self.recentThoughts.append(thought)
131130 self.broadcast(thought)
132131
@@ -178,7 +177,6 @@ class IRCAgent(Agent, IRCClient):
178177 user = user.split("!", 1)[0]
179178 self.broadcast(irc_line(datetime.now(), channel, user, line))
180179 if self.nickname in line:
181- print("~ Ping on IRC:", self.nickname)
182180 d = self.mind.infer("irc", self.prefix(channel))
183181
184182 @d.addCallback
@@ -205,7 +203,7 @@ def go():
205203 clock = Clock()
206204 LoopingCall(clock.go).start(300.0, now=False)
207205
208- for logpath in sys.argv[1:]:
206+ for logpath in sys.argv[2:]:
209207 character = load_character(logpath)
210208 title = character["title"]
211209 firstStatement = f"I am {title}."
--- a/src/gens/camelid.py
+++ b/src/gens/camelid.py
@@ -4,17 +4,15 @@ from common import Timer
44
55 from llama_cpp import Llama, llama_cpp
66
7-MODEL = "/home/simpson/models/export/llama/30b-4bit.bin"
8-MODEL_SIZE = os.stat(MODEL).st_size * 3 // 2
9-
107 class CamelidGen:
118 model_name = "LLaMA?"
129 model_arch = "LLaMA"
13- def __init__(self, max_new_tokens):
14- self.llama = Llama(MODEL, n_ctx=1024)
10+ def __init__(self, model_path, max_new_tokens):
11+ self.llama = Llama(model_path, n_ctx=1024)
12+ self.model_size = os.stat(model_path).st_size * 3 // 2
1513 self.max_new_tokens = max_new_tokens
1614
17- def footprint(self): return MODEL_SIZE
15+ def footprint(self): return self.model_size
1816 def contextLength(self): return llama_cpp.llama_n_ctx(self.llama.ctx)
1917
2018 # XXX doesn't work?
@@ -31,7 +29,8 @@ class CamelidGen:
3129 class CamelidEmbed:
3230 # XXX is that still the case?
3331 embedding_width = 4096
34- def __init__(self): self.llama = Llama(MODEL, embedding=True)
32+ def __init__(self, model_path):
33+ self.llama = Llama(model_path, embedding=True)
3534
3635 def embed(self, s):
3736 with Timer("embedding"):
--- a/src/gens/mawrkov.py
+++ b/src/gens/mawrkov.py
@@ -4,8 +4,6 @@ import sys
44
55 import tokenizers
66
7-from common import Timer
8-
97 # Monkey-patch to get rwkv available.
108 RWKV = "@RWKV@"
119 RWKV_PATH = os.path.join(RWKV, "bin")
@@ -22,8 +20,6 @@ rwkv_cpp_shared_library = bare_import(RWKV_PATH, "rwkv_cpp_shared_library")
2220 rwkv_cpp_model = bare_import(RWKV_PATH, "rwkv_cpp_model")
2321 sampling = bare_import(RWKV_PATH, "sampling")
2422
25-MODEL = "/home/simpson/models/export/rwkv/rwkv-pile-14b-Q4_3.bin"
26-MODEL_SIZE = os.stat(MODEL).st_size * 3 // 2
2723 TOKENIZER_PATH = os.path.join(RWKV, "share", "20B_tokenizer.json")
2824
2925 # Upstream recommends temp 0.7, top_p 0.5
@@ -31,15 +27,18 @@ TEMPERATURE = 0.8
3127 TOP_P = 0.8
3228
3329 class MawrkovGen:
30+ # XXX might be wrong
3431 model_name = "The Pile (14B params, 4-bit quantized)"
3532 model_arch = "RWKV"
36- def __init__(self, max_new_tokens):
33+ def __init__(self, model_path, max_new_tokens):
3734 self.max_new_tokens = max_new_tokens
35+ self.model_size = os.stat(model_path).st_size * 3 // 2
3836 self.tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
3937 self.lib = rwkv_cpp_shared_library.load_rwkv_shared_library()
40- self.model = rwkv_cpp_model.RWKVModel(self.lib, MODEL)
38+ self.model = rwkv_cpp_model.RWKVModel(self.lib, model_path)
4139
42- def footprint(self): return MODEL_SIZE
40+ # XXX wrong
41+ def footprint(self): return self.model_size
4342 def contextLength(self): return 8192
4443 def tokenize(self, s): return self.tokenizer.encode(s).ids
4544 def countTokens(self, s): return len(self.tokenize(s))
@@ -52,10 +51,9 @@ class MawrkovGen:
5251 def complete(self, s, logits, state):
5352 logits, state = self.feedForward(self.tokenize(s), logits, state)
5453 tokens = []
55- with Timer("completion"):
56- for i in range(self.max_new_tokens):
57- token = sampling.sample_logits(logits, TEMPERATURE, TOP_P)
58- tokens.append(token)
59- logits, state = self.feedForward([token], logits, state)
60- if "\n" in self.tokenizer.decode([token]): break
54+ for i in range(self.max_new_tokens):
55+ token = sampling.sample_logits(logits, TEMPERATURE, TOP_P)
56+ tokens.append(token)
57+ logits, state = self.feedForward([token], logits, state)
58+ if "\n" in self.tokenizer.decode([token]): break
6159 return self.tokenizer.decode(tokens).split("\n", 1)[0], logits, state