将100个汉字的手写图片数据放在data目录下（每个汉字一个子目录，目录名为5位编号，如00000～00099）。将下述代码文件与data目录放在同一个目录下直接运行即可。
关键参数:
run_mode = "train" 表示训练模型；修改为"validation"表示在测试集（默认为./data/目录下的全部图片）上评估top-1和top-3预测精度；修改为"inference"表示识别'./data/00098/102544.png'这张手写图片，返回top-3结果。
charset_size = 100 表示汉字数目（即分类类别数）。如果使用全量数据，则为3755。
代码参考了:https://github.com/burness/tensorflow-101/blob/master/chinese_hand_write_rec/src/chinese_rec.py
其中加入了：(1) 图像随机左右旋转（不超过30度）的数据增强；(2) 从checkpoint断点恢复、继续训练；(3) 为了达到更高精度，增加了一个卷积层，参见 https://github.com/AmemiyaYuko/HandwrittenChineseCharacterRecognition
import logging
import math
import os
import pickle
import random
import time

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from PIL import Image

# Module-level logger that prints INFO and above to stderr.
logger = logging.getLogger('Training a chinese write char recognition')
logger.setLevel(logging.INFO)
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
logger.addHandler(ch)

# Default run configuration; every value can be overridden by the flags below.
run_mode = "train"          # one of "train", "validation", "inference"
charset_size = 100          # number of character classes; 3755 for the full dataset
max_steps = 12002
save_steps = 2000

# Paths used for online training on all 3755 characters:
# checkpoint_dir = '/aiml/dfs/checkpoint/'
# train_data_dir = '/aiml/data/train/'
# test_data_dir = '/aiml/data/test/'
# log_dir = '/aiml/dfs/'
checkpoint_dir = './checkpoint2/'
train_data_dir = './data/'
test_data_dir = './data/'
log_dir = './'

tf.app.flags.DEFINE_string('mode', run_mode, 'Running mode. One of {"train", "validation", "inference"}')
# Historical name: this flag enables a random rotation augmentation, not a flip.
tf.app.flags.DEFINE_boolean('random_flip_up_down', True, "Whether to apply random rotation augmentation")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random contrast")
tf.app.flags.DEFINE_integer('charset_size', charset_size,
                            "Choose the first `charset_size` character to conduct our experiment.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
# NOTE(review): `gray` does not appear to be read anywhere in this file -- TODO confirm.
tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rgb to gray")
tf.app.flags.DEFINE_integer('max_steps', max_steps, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', save_steps, "the steps to save")
tf.app.flags.DEFINE_string('checkpoint_dir', checkpoint_dir, 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', train_data_dir, 'the train dataset dir')
tf.app.flags.DEFINE_string('test_data_dir', test_data_dir, 'the test dataset dir')
tf.app.flags.DEFINE_string('log_dir', log_dir, 'the logging dir')
# ----------------------------- resume training
tf.app.flags.DEFINE_boolean('restore', True, 'whether to restore from checkpoint')
# -----------------------------
# `epoch` and `batch_size` hold integers. They were previously declared with
# DEFINE_boolean, which cannot represent values such as 128 or --batch_size=64.
# NOTE: `epoch` is not read anywhere in this script.
tf.app.flags.DEFINE_integer('epoch', 10, 'Number of epoches')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size for training and validation')
FLAGS = tf.app.flags.FLAGS


class DataIterator:
    """List labelled image files under a directory and expose them as a TF input queue.

    Expected layout: ``data_dir/<label>/<file>`` where ``<label>`` is a zero-padded
    5-digit class index (e.g. ``00042``). Only class folders whose path sorts
    before ``data_dir + '%05d' % FLAGS.charset_size`` are kept, i.e. the first
    ``FLAGS.charset_size`` classes.

    Attributes:
        image_names: shuffled list of image file paths.
        labels: integer class label for each entry of ``image_names``.
    """

    def __init__(self, data_dir):
        # Both the prefix comparison and the label slicing below require a
        # trailing separator ('./data' would otherwise keep every class and
        # parse an empty label).
        data_dir = os.path.join(data_dir, '')
        # Set FLAGS.charset_size to a small value if available computation power is limited.
        truncate_path = data_dir + ('%05d' % FLAGS.charset_size)
        print(truncate_path)
        self.image_names = []
        for root, _sub_folders, file_list in os.walk(data_dir):
            # Files directly in data_dir have no class folder and hence no label.
            if root == data_dir:
                continue
            if root < truncate_path:
                self.image_names += [os.path.join(root, file_path) for file_path in file_list]
        random.shuffle(self.image_names)
        # The label is the first path component below data_dir (the class folder name).
        self.labels = [int(file_name[len(data_dir):].split(os.sep)[0]) for file_name in self.image_names]

    @property
    def size(self):
        """Number of images found."""
        return len(self.labels)

    @staticmethod
    def data_augmentation(images):
        """Apply the augmentations enabled by the flags to a single image tensor."""
        if FLAGS.random_flip_up_down:
            # Despite the flag name this is a random rotation in [-30, 30] degrees.
            # The angle must be a graph op: a Python random number would be drawn
            # once at graph-construction time, rotating every image identically.
            angle = tf.random_uniform([], minval=-30.0, maxval=30.0) * math.pi / 180
            images = tf.contrib.image.rotate(images, angle, interpolation='BILINEAR')
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.3)
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.8, 1.2)
        return images

    def input_pipeline(self, batch_size, num_epochs=None, aug=False):
        """Build a shuffled batching queue over this iterator's files.

        Args:
            batch_size: number of images per batch.
            num_epochs: if None, cycle forever; otherwise stop after that many
                passes (an OutOfRangeError is raised by the dequeue afterwards,
                and the final, possibly smaller, batch is still emitted).
            aug: whether to apply ``data_augmentation``.

        Returns:
            (image_batch, label_batch) tensors; images are float32 in [0, 1],
            resized to FLAGS.image_size x FLAGS.image_size with 1 channel.
        """
        images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
        labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
        input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)
        labels = input_queue[1]
        images_content = tf.read_file(input_queue[0])
        images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
        if aug:
            images = self.data_augmentation(images)
        new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
        images = tf.image.resize_images(images, new_size)
        # For a finite number of epochs, keep the last partial batch so that every
        # image is seen (relevant for validation accuracy).
        image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size,
                                                          capacity=50000, min_after_dequeue=10000,
                                                          allow_smaller_final_batch=num_epochs is not None)
        return image_batch, label_batch


def build_graph(top_k):
    """Build the CNN classifier graph and return a dict of its endpoints.

    Args:
        top_k: k used for the top-k prediction and top-k accuracy ops.

    Returns:
        dict with placeholders ('images', 'labels', 'keep_prob') and ops
        ('train_op', 'loss', 'accuracy', 'accuracy_top_k', 'merged_summary_op',
        'predicted_distribution', 'predicted_index_top_k', 'predicted_val_top_k',
        'global_step'), plus the 'top_k' value itself.
    """
    # with tf.device('/cpu:0'):
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    images = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.image_size, FLAGS.image_size, 1],
                            name='image_batch')
    labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

    conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
    max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
    conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
    max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
    conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
    max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')
    # Extra (fourth) conv layer, strided, added for higher accuracy.
    conv_4 = slim.conv2d(max_pool_3, 512, [3, 3], [2, 2], scope="conv4", padding="SAME")
    max_pool_4 = slim.max_pool2d(conv_4, [2, 2], [2, 2], padding="SAME")

    flatten = slim.flatten(max_pool_4)
    fc1 = slim.fully_connected(slim.dropout(flatten, keep_prob), 1024,
                               activation_fn=tf.nn.tanh, scope='fc1')
    logits = slim.fully_connected(slim.dropout(fc1, keep_prob), FLAGS.charset_size,
                                  activation_fn=None, scope='fc2')

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

    # Kept as a float32 variable named "step" so existing checkpoints still restore.
    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)
    probabilities = tf.nn.softmax(logits)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()
    predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
    accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))

    return {'images': images,
            'labels': labels,
            'keep_prob': keep_prob,
            'top_k': top_k,
            'global_step': global_step,
            'train_op': train_op,
            'loss': loss,
            'accuracy': accuracy,
            'accuracy_top_k': accuracy_in_top_k,
            'merged_summary_op': merged_summary_op,
            'predicted_distribution': probabilities,
            'predicted_index_top_k': predicted_index_top_k,
            'predicted_val_top_k': predicted_val_top_k}


def train():
    """Train the model until the global step exceeds FLAGS.max_steps.

    Evaluates one test batch every FLAGS.eval_steps steps, checkpoints every
    FLAGS.save_steps steps and once more at the end. If FLAGS.restore is set and
    a checkpoint exists in FLAGS.checkpoint_dir, training resumes from it (the
    global step variable is restored along with the weights).
    """
    print('Begin training')
    # tf.train.Saver.save() does not create a missing parent directory.
    if not os.path.isdir(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    model_path = os.path.join(FLAGS.checkpoint_dir, 'my-model')
    train_feeder = DataIterator(FLAGS.train_data_dir)
    test_feeder = DataIterator(FLAGS.test_data_dir)
    with tf.Session() as sess:
        train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
        graph = build_graph(top_k=1)
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()

        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))
        logger.info(':::Training Start:::')
        try:
            while not coord.should_stop():
                start_time = time.time()
                train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                feed_dict = {graph['images']: train_images_batch,
                             graph['labels']: train_labels_batch,
                             graph['keep_prob']: 0.8}
                _, loss_val, train_summary, step = sess.run(
                    [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                    feed_dict=feed_dict)
                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    break
                if step % FLAGS.eval_steps == 1:
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    accuracy_test, test_summary = sess.run(
                        [graph['accuracy'], graph['merged_summary_op']],
                        feed_dict=feed_dict)
                    test_writer.add_summary(test_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'
                                .format(step, accuracy_test))
                    logger.info('===============Eval a batch=======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, model_path, global_step=graph['global_step'])
        except tf.errors.OutOfRangeError:
            logger.info('==================Train Finished================')
        finally:
            coord.request_stop()
            coord.join(threads)
            train_writer.close()
            test_writer.close()
        # Reached on both normal exits (max_steps `break` or exhausted input), so
        # the final weights are always persisted, not only the last periodic save.
        saver.save(sess, model_path, global_step=graph['global_step'])


def validation():
    """Evaluate the latest checkpoint on every image of FLAGS.test_data_dir.

    Returns:
        dict with 'prob' (top-3 probabilities per image), 'indices' (top-3
        class indices per image) and 'groundtruth' (true label per image).
    """
    print('validation')
    test_feeder = DataIterator(FLAGS.test_data_dir)

    final_predict_val = []
    final_predict_index = []
    groundtruth = []

    with tf.Session() as sess:
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
        graph = build_graph(top_k=3)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # initialize test_feeder's inside state (epoch counter)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))
        else:
            logger.warning('no checkpoint found in {0}; validating untrained weights'
                           .format(FLAGS.checkpoint_dir))

        print(':::Start validation:::')
        i = 0
        n_seen = 0
        # Sums of correctly classified images (batch accuracy * batch length).
        correct_top_1, correct_top_k = 0.0, 0.0
        try:
            while not coord.should_stop():
                i += 1
                start_time = time.time()
                test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                feed_dict = {graph['images']: test_images_batch,
                             graph['labels']: test_labels_batch,
                             graph['keep_prob']: 1.0}
                batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],
                                                                       graph['predicted_val_top_k'],
                                                                       graph['predicted_index_top_k'],
                                                                       graph['accuracy'],
                                                                       graph['accuracy_top_k']],
                                                                      feed_dict=feed_dict)
                final_predict_val += probs.tolist()
                final_predict_index += indices.tolist()
                groundtruth += batch_labels.tolist()
                # Weight by batch length: the final batch may be smaller.
                batch_len = len(batch_labels)
                correct_top_1 += acc_1 * batch_len
                correct_top_k += acc_k * batch_len
                n_seen += batch_len
                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                            .format(i, end_time - start_time, acc_1, acc_k))

        except tf.errors.OutOfRangeError:
            logger.info('==================Validation Finished================')
            acc_top_1 = correct_top_1 / max(n_seen, 1)
            acc_top_k = correct_top_k / max(n_seen, 1)
            logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
        finally:
            coord.request_stop()
            coord.join(threads)
    return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}


def inference(image):
    """Predict the top-3 classes for a single image file.

    Args:
        image: path to an image file; it is converted to grayscale, resized to
            FLAGS.image_size x FLAGS.image_size and scaled to [0, 1].

    Returns:
        (predict_val, predict_index): arrays of shape (1, 3) holding the top-3
        probabilities and their class indices.

    Raises:
        RuntimeError: if no checkpoint exists in FLAGS.checkpoint_dir.
    """
    print('inference')
    with Image.open(image) as img:
        temp_image = img.convert('L')
    # LANCZOS is the filter formerly exposed as ANTIALIAS (removed in Pillow 10).
    temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.LANCZOS)
    temp_image = np.asarray(temp_image) / 255.0
    temp_image = temp_image.reshape([-1, FLAGS.image_size, FLAGS.image_size, 1])
    with tf.Session() as sess:
        logger.info('========start inference============')
        graph = build_graph(top_k=3)
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if not ckpt:
            # Variables are never initialised here, so running without a checkpoint cannot work.
            raise RuntimeError('no checkpoint found in {0}'.format(FLAGS.checkpoint_dir))
        saver.restore(sess, ckpt)

        predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                              feed_dict={graph['images']: temp_image,
                                                         graph['keep_prob']: 1.0})
    return predict_val, predict_index


def main(_):
    """Dispatch on FLAGS.mode: "train", "validation" or "inference"."""
    print(FLAGS.mode)
    if FLAGS.mode == "train":
        train()
    elif FLAGS.mode == 'validation':
        dct = validation()
        result_file = 'result.dict'
        logger.info('Write result into {0}'.format(result_file))
        with open(result_file, 'wb') as f:
            pickle.dump(dct, f)
        logger.info('Write file ends')
    elif FLAGS.mode == 'inference':
        image_path = './data/00098/102544.png'
        final_predict_val, final_predict_index = inference(image_path)
        # The true label is the image's class folder name (same convention as DataIterator).
        label = int(os.path.basename(os.path.dirname(image_path)))
        logger.info('the result info label {0} predict index {1} predict_val {2}'.format(
            label, final_predict_index, final_predict_val))
    else:
        raise ValueError('Unknown mode {0!r}; expected "train", "validation" or "inference"'
                         .format(FLAGS.mode))


if __name__ == "__main__":
    tf.app.run()