• R/O
  • HTTP
  • SSH
  • HTTPS

Commit

Tags
No Tags

Frequently used words (click to add to your profile)

javac++androidlinuxc#windowsobjective-ccocoa誰得qtpythonphprubygameguibathyscaphec計画中(planning stage)翻訳omegatframeworktwitterdomtestvb.netdirectxゲームエンジンbtronarduinopreviewer

TensorFlowサンプルコード


Commit MetaInfo

Revision47c42744f7cdbefcb9dd7b1c09ee0e27076560c9 (tree)
Time2018-01-18 19:22:38
Authorhylom <hylom@hylo...>
Commiterhylom

Log Message

add tensorboard_test5.py

Change Summary

Incremental Difference

--- /dev/null
+++ b/tensorboard/tensorboard_test5.py
@@ -0,0 +1,167 @@
1+#!/usr/bin/env python
2+# -*- coding: utf-8 -*-
3+
4+import tensorflow as tf
5+
6+INPUT_SIZE = 15
7+W1_SIZE = 15
8+OUTPUT_SIZE = 10
9+
10+with tf.variable_scope('model') as scope:
11+
12+ # 入力
13+ x1 = tf.placeholder(dtype=tf.float32, name="x1")
14+ y = tf.placeholder(dtype=tf.float32, name="y")
15+
16+ # 第2層
17+ tf.set_random_seed(1234)
18+ W1 = tf.get_variable("W1",
19+ shape=[INPUT_SIZE, W1_SIZE],
20+ dtype=tf.float32,
21+ initializer=tf.random_normal_initializer(stddev=0.05))
22+ b1 = tf.get_variable("b1",
23+ shape=[W1_SIZE],
24+ dtype=tf.float32,
25+ initializer=tf.random_normal_initializer(stddev=0.05))
26+ x2 = tf.sigmoid(tf.matmul(x1, W1) + b1, name="x2")
27+
28+ # W1のヒストグラムを記録
29+ tf.summary.histogram('W1', W1)
30+
31+ # 第3層
32+ W2 = tf.get_variable("W2",
33+ shape=[W1_SIZE, OUTPUT_SIZE],
34+ dtype=tf.float32,
35+ initializer=tf.random_normal_initializer(stddev=0.05))
36+ b2 = tf.get_variable("b2",
37+ shape=[OUTPUT_SIZE],
38+ dtype=tf.float32,
39+ initializer=tf.random_normal_initializer(stddev=0.05))
40+ x3 = tf.nn.softmax(tf.matmul(x2, W2) + b2, name="x3")
41+
42+ # コスト関数
43+ cross_entropy = -tf.reduce_sum(y * tf.log(x3), name="cross_entropy")
44+ tf.summary.scalar('cross_entropy', cross_entropy)
45+
46+ # 正答率
47+ correct = tf.equal(tf.argmax(x3,1), tf.argmax(y,1), name="correct")
48+ accuracy = tf.reduce_mean(tf.cast(correct, "float"), name="accuracy")
49+ tf.summary.scalar('accuracy', accuracy)
50+
51+
52+ # 最適化アルゴリズムを定義
53+ global_step = tf.Variable(0, name='global_step', trainable=False)
54+ optimizer = tf.train.GradientDescentOptimizer(0.01, name="optimizer")
55+ minimize = optimizer.minimize(cross_entropy, global_step=global_step, name="minimize")
56+
57+ # 学習結果を保存するためのオブジェクトを用意
58+ saver = tf.train.Saver()
59+
60+
61+with tf.variable_scope('pipeline') as scope:
62+ ## データセットを読み込むためのパイプラインを作成する
63+ # リーダーオブジェクトを作成する
64+ reader = tf.TextLineReader()
65+
66+ # 読み込む対象のファイルを格納したキューを作成する
67+ file_queue = tf.train.string_input_producer(["digits_data.csv", "test_data.csv"])
68+
69+ # キューからデータを読み込む
70+ key, value = reader.read(file_queue)
71+
72+ # 読み込んだCSV型式データをデコードする
73+ # [[] for i in range(16)] は
74+ # [[], [], [], [], [], [], [], [],
75+ # [], [], [], [], [], [], [], []]に相当
76+ data = tf.decode_csv(value, record_defaults=[[] for i in range(16)])
77+
78+ # 10件のデータを読み出す
79+ # 10件ずつデータを読み出す
80+ # 第1カラム(data[0])はその文字が示す数だが、
81+ # ニューラルネットワークの出力は10要素の1次元テンソルとなる。
82+ # そのため、10×10の対角行列を作成し、そのdata[0]行目を取り出す操作を行うことで
83+ # 1次元テンソルに変換する。dataは浮動小数点小数型なので、このとき
84+ # int32型にキャストして使用する
85+ data_x, data_y, y_value = tf.train.batch([
86+ tf.stack(data[1:]),
87+ tf.reshape(tf.slice(tf.eye(10), [tf.cast(data[0], tf.int32), 0], [1, 10]), [10]),
88+ tf.cast(data[0], tf.int64),
89+ ], 10)
90+
91+# セッションの作成
92+sess = tf.Session()
93+
94+# 変数の初期化を実行する
95+sess.run(tf.global_variables_initializer())
96+
97+# 学習結果を保存したファイルが存在するかを確認し、
98+# 存在していればそれを読み出す
99+latest_filename = tf.train.latest_checkpoint("./")
100+if latest_filename:
101+ print("load saved model {}".format(latest_filename))
102+ saver.restore(sess, latest_filename)
103+
104+# サマリを取得するための処理
105+summary_op = tf.summary.merge_all()
106+summary_writer = tf.summary.FileWriter('data', graph=sess.graph)
107+
108+
109+# コーディネータの作成
110+coord = tf.train.Coordinator()
111+
112+# キューの開始
113+threads = tf.train.start_queue_runners(sess=sess, coord=coord)
114+
115+# ファイルからのデータの読み出し
116+# 1回目のデータ読み込み。1つ目のファイルから10件のデータが読み込まれる
117+# 1つ目のファイルには10件のデータがあるので、これで全データが読み込まれる
118+dataset_x, dataset_y, values_y = sess.run([data_x, data_y, y_value])
119+
120+# 2回目のデータ読み込み。1つ目のファイルのデータはすべて読み出したので、
121+# 続けて2つ目のファイルから読み込みが行われる。
122+testdata_x, testdata_y, testvalues_y = sess.run([data_x, data_y, y_value])
123+
124+# 学習を開始
125+for i in range(100):
126+ for j in range(100):
127+ _, summary = sess.run([minimize, summary_op], {x1: dataset_x, y: dataset_y})
128+ print("CROSS ENTROPY:", sess.run(cross_entropy, {x1: dataset_x, y: dataset_y}))
129+ summary_writer.add_summary(summary, global_step=tf.train.global_step(sess, global_step))
130+
131+# 結果を保存する
132+save_path = saver.save(sess, "./model", global_step=tf.train.global_step(sess, global_step))
133+print("Model saved to {}".format(save_path))
134+
135+## 結果の出力
136+# 出力テンソルの中でもっとも値が大きいもののインデックスが
137+# 正答と等しいかどうかを計算する
138+y_value = tf.placeholder(dtype=tf.int64)
139+correct = tf.equal(tf.argmax(x3,1), y_value)
140+accuracy = tf.reduce_mean(tf.cast(correct, "float"))
141+
142+# 学習に使用したデータを入力した場合の
143+# ニューラルネットワークの出力を表示
144+print("----result----")
145+print("raw output:")
146+print(sess.run(x3,feed_dict={x1: dataset_x}))
147+print("answers:", sess.run(tf.argmax(x3, 1), feed_dict={x1: dataset_x}))
148+
149+# このときの正答率を出力
150+print("accuracy:", sess.run(accuracy, feed_dict={x1: dataset_x, y_value: values_y}))
151+
152+
153+# テスト用データを入力した場合の
154+# ニューラルネットワークの出力を表示
155+print("----test----")
156+print("raw output:")
157+print(sess.run(x3,feed_dict={x1: testdata_x}))
158+print("answers:", sess.run(tf.argmax(x3, 1), feed_dict={x1: testdata_x}))
159+
160+# このときの正答率を出力
161+print("accuracy:", sess.run(accuracy, feed_dict={x1: testdata_x, y_value: testvalues_y}))
162+
163+
164+
165+# キューの終了
166+coord.request_stop()
167+coord.join(threads)