Emergent generative agents
Revision:  710ae6e7774980d003893e18f50bdabc95a00f6f
Time:      2023-05-12 04:28:12
Author:    Corbin <cds@corb...>
Committer: Corbin
Parameterize model path.
So that it can run on other people's machines!
@@ -17,9 +17,7 @@ from twisted.internet.threads import deferToThread
 from twisted.words.protocols.irc import IRCClient
 
 from common import irc_line, Timer, SentenceIndex, breakAt
-# from gens.camelid import CamelidGen
 from gens.mawrkov import MawrkovGen
-# from gens.trans import Flavor, HFGen, SentenceEmbed
 from gens.trans import SentenceEmbed
 
 build_traits = " + ".join
@@ -32,7 +30,8 @@ def load_character(path):
 
 MAX_NEW_TOKENS = 128
 print("~ Initializing mawrkov adapter…")
-gen = MawrkovGen(MAX_NEW_TOKENS)
+model_path = sys.argv[1]
+gen = MawrkovGen(model_path, MAX_NEW_TOKENS)
 # Need to protect per-gen data structures in C.
 genLock = Lock()
 GiB = 1024 ** 3
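The new argv contract is positional and unchecked. A defensive variant, as a sketch only (the script name and usage string are hypothetical, not part of the patch):

    import sys

    if len(sys.argv) < 2:
        sys.exit("usage: agents.py MODEL_PATH CHARACTER_LOG...")
    model_path = sys.argv[1]
    gen = MawrkovGen(model_path, MAX_NEW_TOKENS)
    print(f"~ Model footprint: {gen.footprint() / GiB:.2f} GiB")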
@@ -72,16 +71,17 @@ class Mind:
         # Newlines are added here.
         self.logits, self.state = gen.feedForward(gen.tokenize(s + "\n"),
                                                   self.logits, self.state)
-        print("~ Write:", s)
+        print(s)
 
     def complete(self, s):
         with genLock:
             completion, self.logits, self.state = gen.complete(s, self.logits, self.state)
-        print("~ Predicted:", completion)
+        print("«", completion, "»")
         return completion
 
     def infer(self, tag, prefix):
         def cb():
+            print(prefix)
             d = self.switchTag(tag)
             d.addCallback(lambda _: deferToThread(self.complete, prefix))
             return d
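Every gen.* call mutates shared state on the C side, so it runs under genLock, and deferToThread keeps the blocking completion off Twisted's reactor thread, handing the result back as a Deferred. A usage sketch of infer(), mirroring the IRC handler further down (names hypothetical):

    # complete() runs in a worker thread; the callback fires back on the reactor.
    d = mind.infer("irc", "agent:")

    @d.addCallback
    def handle(completion):
        print("got:", completion)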
@@ -126,7 +126,6 @@ class ChainOfThoughts(Agent):
         thoughts = self.index.search(s, 2)
         for thought in thoughts:
             if thought not in self.recentThoughts:
-                print("~ New relevant thought:", thought)
                 self.recentThoughts.append(thought)
                 self.broadcast(thought)
 
@@ -178,7 +177,6 @@ class IRCAgent(Agent, IRCClient):
         user = user.split("!", 1)[0]
         self.broadcast(irc_line(datetime.now(), channel, user, line))
         if self.nickname in line:
-            print("~ Ping on IRC:", self.nickname)
             d = self.mind.infer("irc", self.prefix(channel))
 
             @d.addCallback
@@ -205,7 +203,7 @@ def go():
     clock = Clock()
     LoopingCall(clock.go).start(300.0, now=False)
 
-    for logpath in sys.argv[1:]:
+    for logpath in sys.argv[2:]:
         character = load_character(logpath)
         title = character["title"]
         firstStatement = f"I am {title}."
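Since the model path now occupies argv[1], character logs shift to argv[2:]. A hypothetical invocation (script and file names assumed):

    python agents.py /path/to/rwkv-pile-14b-Q4_3.bin alice.log bob.log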
@@ -4,17 +4,15 @@ from common import Timer
 
 from llama_cpp import Llama, llama_cpp
 
-MODEL = "/home/simpson/models/export/llama/30b-4bit.bin"
-MODEL_SIZE = os.stat(MODEL).st_size * 3 // 2
-
 class CamelidGen:
     model_name = "LLaMA?"
     model_arch = "LLaMA"
-    def __init__(self, max_new_tokens):
-        self.llama = Llama(MODEL, n_ctx=1024)
+    def __init__(self, model_path, max_new_tokens):
+        self.llama = Llama(model_path, n_ctx=1024)
+        self.model_size = os.stat(model_path).st_size * 3 // 2
         self.max_new_tokens = max_new_tokens
 
-    def footprint(self): return MODEL_SIZE
+    def footprint(self): return self.model_size
     def contextLength(self): return llama_cpp.llama_n_ctx(self.llama.ctx)
 
     # XXX doesn't work?
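The llama.cpp adapter now takes its path at construction, and the size-based memory estimate moves onto the instance with it. A minimal sketch, path hypothetical:

    gen = CamelidGen("/path/to/llama-30b-4bit.bin", max_new_tokens=128)
    print(gen.footprint())      # rough RAM estimate: file size * 3 // 2
    print(gen.contextLength())  # n_ctx as reported by llama_cpp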
@@ -31,7 +29,8 @@ class CamelidGen:
 class CamelidEmbed:
     # XXX is that still the case?
     embedding_width = 4096
-    def __init__(self): self.llama = Llama(MODEL, embedding=True)
+    def __init__(self, model_path):
+        self.llama = Llama(model_path, embedding=True)
 
     def embed(self, s):
         with Timer("embedding"):
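CamelidEmbed gets the same treatment. A usage sketch, assuming embed() returns one vector per input string:

    embedder = CamelidEmbed("/path/to/llama-30b-4bit.bin")
    vec = embedder.embed("Hello, world.")
    assert len(vec) == CamelidEmbed.embedding_width  # 4096, per the XXX above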
@@ -4,8 +4,6 @@ import sys
 
 import tokenizers
 
-from common import Timer
-
 # Monkey-patch to get rwkv available.
 RWKV = "@RWKV@"
 RWKV_PATH = os.path.join(RWKV, "bin")
@@ -22,8 +20,6 @@ rwkv_cpp_shared_library = bare_import(RWKV_PATH, "rwkv_cpp_shared_library")
 rwkv_cpp_model = bare_import(RWKV_PATH, "rwkv_cpp_model")
 sampling = bare_import(RWKV_PATH, "sampling")
 
-MODEL = "/home/simpson/models/export/rwkv/rwkv-pile-14b-Q4_3.bin"
-MODEL_SIZE = os.stat(MODEL).st_size * 3 // 2
 TOKENIZER_PATH = os.path.join(RWKV, "share", "20B_tokenizer.json")
 
 # Upstream recommends temp 0.7, top_p 0.5
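bare_import, defined above this hunk, loads the rwkv.cpp helper modules from the Nix-substituted path (RWKV = "@RWKV@"). Its definition is not shown here; a plausible importlib-based reconstruction, offered strictly as a guess:

    import importlib.util, os

    def bare_import(path, name):
        # Hypothetical: load <path>/<name>.py without touching sys.path.
        spec = importlib.util.spec_from_file_location(
            name, os.path.join(path, name + ".py"))
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module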
@@ -31,15 +27,18 @@ TEMPERATURE = 0.8
 TOP_P = 0.8
 
 class MawrkovGen:
+    # XXX might be wrong
     model_name = "The Pile (14B params, 4-bit quantized)"
     model_arch = "RWKV"
-    def __init__(self, max_new_tokens):
+    def __init__(self, model_path, max_new_tokens):
         self.max_new_tokens = max_new_tokens
+        self.model_size = os.stat(model_path).st_size * 3 // 2
         self.tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
         self.lib = rwkv_cpp_shared_library.load_rwkv_shared_library()
-        self.model = rwkv_cpp_model.RWKVModel(self.lib, MODEL)
+        self.model = rwkv_cpp_model.RWKVModel(self.lib, model_path)
 
-    def footprint(self): return MODEL_SIZE
+    # XXX wrong
+    def footprint(self): return self.model_size
     def contextLength(self): return 8192
     def tokenize(self, s): return self.tokenizer.encode(s).ids
     def countTokens(self, s): return len(self.tokenize(s))
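A construction sketch for the RWKV side (path hypothetical). The author flags the size heuristic with `XXX wrong` here, so footprint() is best read as a rough estimate:

    gen = MawrkovGen("/path/to/rwkv-pile-14b-Q4_3.bin", max_new_tokens=128)
    print(gen.countTokens("I am a helpful agent."))  # tokenizer round trip
    print(gen.footprint())                           # flagged XXX wrong above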
@@ -52,10 +51,9 @@ class MawrkovGen:
     def complete(self, s, logits, state):
         logits, state = self.feedForward(self.tokenize(s), logits, state)
         tokens = []
-        with Timer("completion"):
-            for i in range(self.max_new_tokens):
-                token = sampling.sample_logits(logits, TEMPERATURE, TOP_P)
-                tokens.append(token)
-                logits, state = self.feedForward([token], logits, state)
-                if "\n" in self.tokenizer.decode([token]): break
+        for i in range(self.max_new_tokens):
+            token = sampling.sample_logits(logits, TEMPERATURE, TOP_P)
+            tokens.append(token)
+            logits, state = self.feedForward([token], logits, state)
+            if "\n" in self.tokenizer.decode([token]): break
         return self.tokenizer.decode(tokens).split("\n", 1)[0], logits, state
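complete() samples one token at a time and stops at the first newline, so each call yields at most one line; the trailing split("\n", 1)[0] handles a token that decodes to more than the newline itself. An end-to-end sketch, assuming fresh logits and state start as None (as the Mind class appears to do):

    logits, state = None, None
    logits, state = gen.feedForward(gen.tokenize("I am Alice.\n"), logits, state)
    line, logits, state = gen.complete("Alice: ", logits, state)
    print("«", line, "»")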