博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
cnn handwrite使用原生的TensorFlow进行预测
阅读量:6215 次
发布时间:2019-06-21

本文共 15362 字,大约阅读时间需要 51 分钟。

数据集为100个汉字的手写图片,按类别存放在data目录下(例如 ./data/00098/102544.png)。将下述代码文件与data目录放在同一目录下直接运行即可。

关键参数:

run_mode = "train" 表示训练模型;修改为 validation 表示在测试数据(前100个汉字类别的全部图片)上评估预测精度;修改为 inference 表示对 './data/00098/102544.png' 这张图片进行手写识别,返回 top-3 结果。

charset_size = 100 表示参与训练的汉字类别数;如果使用全量数据,则设为 3755。

代码参考了:https://github.com/burness/tensorflow-101/blob/master/chinese_hand_write_rec/src/chinese_rec.py

其中加入了:(1)图像随机旋转(最大30度)的数据增强;(2)从检查点断点续训;(3)为了达到更高精度,增加了一个卷积层,见 https://github.com/AmemiyaYuko/HandwrittenChineseCharacterRecognition

import logging
import math
import os
import pickle
import random
import time

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from PIL import Image

# Module-level logger printing INFO and above to stderr.
logger = logging.getLogger('Training a chinese write char recognition')
logger.setLevel(logging.INFO)
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
logger.addHandler(ch)

# Default run configuration (each value can be overridden by the flag of the same name).
run_mode = "train"   # one of "train", "validation", "inference"
charset_size = 100   # number of character classes used; 3755 for the full data set
max_steps = 12002    # training stops once the global step exceeds this
save_steps = 2000    # a checkpoint is written when step % save_steps == 1

# Paths used for the full 3755-character online training run:
# checkpoint_dir = '/aiml/dfs/checkpoint/'
# train_data_dir = '/aiml/data/train/'
# test_data_dir = '/aiml/data/test/'
# log_dir = '/aiml/dfs/'
checkpoint_dir = './checkpoint2/'
train_data_dir = './data/'
test_data_dir = './data/'
log_dir = './'

# The help text lists the modes main() actually dispatches on.
tf.app.flags.DEFINE_string('mode', run_mode, 'Running mode. One of {"train", "validation", "inference"}')
# Despite its name, this flag enables the random *rotation* augmentation; the name is
# kept so existing command lines keep working.
tf.app.flags.DEFINE_boolean('random_flip_up_down', True, "Whether to apply a random rotation (name kept for compatibility)")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast")
tf.app.flags.DEFINE_integer('charset_size', charset_size, "Choose the first `charset_size` character to conduct our experiment.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rbg to gray")
tf.app.flags.DEFINE_integer('max_steps', max_steps, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', save_steps, "the steps to save")
tf.app.flags.DEFINE_string('checkpoint_dir', checkpoint_dir, 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', train_data_dir, 'the train dataset dir')
tf.app.flags.DEFINE_string('test_data_dir', test_data_dir, 'the test dataset dir')
tf.app.flags.DEFINE_string('log_dir', log_dir, 'the logging dir')
# Resume training: train() restores the newest checkpoint in FLAGS.checkpoint_dir
# (if one exists) before continuing.
tf.app.flags.DEFINE_boolean('restore', True, 'whether to restore from checkpoint')
# 'epoch' and 'batch_size' are counts, so they must be integer flags. They used to be
# declared with DEFINE_boolean, which absl-based flag parsing rejects for a default of 128.
tf.app.flags.DEFINE_integer('epoch', 10, 'Number of epoches')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size used for training and validation')
FLAGS = tf.app.flags.FLAGS


class DataIterator:
    """Collect labelled image paths and build a queue-based TF1 input pipeline.

    Images are expected under ``data_dir/<class id>/...``, where ``<class id>`` is the
    zero-padded integer label (e.g. ``./data/00098/102544.png``). Only classes whose
    id is below ``FLAGS.charset_size`` are kept.

    Attributes:
        image_names: shuffled list of image file paths.
        labels: integer class ids aligned with ``image_names``.
    """

    def __init__(self, data_dir):
        """Scan ``data_dir`` for class sub-directories and shuffle the samples.

        Set FLAGS.charset_size to a small value if available computation power is limited.
        """
        self.data_dir = data_dir
        samples = []
        for root, dirs, files in os.walk(data_dir):
            rel_dir = os.path.relpath(root, data_dir)
            if rel_dir == os.curdir:
                # Top level: only descend into numeric class directories below
                # charset_size. Files sitting directly in data_dir have no class
                # directory (so no label) and are skipped. This replaces a string
                # comparison with data_dir + '%05d' that relied on 5-digit names and
                # on data_dir ending with '/'.
                dirs[:] = [d for d in dirs
                           if d.isdigit() and int(d) < FLAGS.charset_size]
                continue
            label = int(rel_dir.split(os.sep)[0])
            samples.extend((os.path.join(root, name), label) for name in files)
        random.shuffle(samples)
        self.image_names = [path for path, _ in samples]
        self.labels = [label for _, label in samples]
        logger.info('{0}: {1} images from classes below {2}'.format(
            data_dir, len(samples), FLAGS.charset_size))

    @property
    def size(self):
        """Number of samples found on disk."""
        return len(self.labels)

    @staticmethod
    def data_augmentation(images):
        """Randomly rotate and re-light one decoded image, as enabled by FLAGS."""
        if FLAGS.random_flip_up_down:
            # Despite the flag name this is a random rotation. The angle is a graph op,
            # so every image gets its own angle in [-30, 30] degrees. The previous
            # random.randint() ran once at graph-construction time and froze one
            # angle for the whole training run.
            degrees = tf.random_uniform([], minval=-30.0, maxval=30.0)
            images = tf.contrib.image.rotate(images, degrees * math.pi / 180,
                                             interpolation='BILINEAR')
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.3)
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.8, 1.2)
        return images

    def input_pipeline(self, batch_size, num_epochs=None, aug=False):
        """Return ``(image_batch, label_batch)`` tensors fed by background queue runners.

        Args:
            batch_size: number of samples per batch.
            num_epochs: forwarded to ``slice_input_producer``. ``None`` cycles forever;
                an integer ends the input with OutOfRangeError after that many epochs
                (its epoch counter is a local variable, see validation()).
            aug: apply ``data_augmentation`` to each image when True.

        Images are decoded as single-channel PNGs, converted to float32 and resized to
        ``FLAGS.image_size`` x ``FLAGS.image_size``.

        Raises:
            ValueError: if no images were found for the selected classes.
        """
        if not self.image_names:
            raise ValueError('No images found under {0} for the first {1} classes'.format(
                self.data_dir, FLAGS.charset_size))
        images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
        labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
        input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)

        labels = input_queue[1]
        images_content = tf.read_file(input_queue[0])
        images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
        if aug:
            images = self.data_augmentation(images)
        new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
        images = tf.image.resize_images(images, new_size)
        image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
                                                          min_after_dequeue=10000)
        return image_batch, label_batch


def build_graph(top_k):
    """Build the 4-conv / 2-FC classifier together with loss, optimiser and metrics.

    Args:
        top_k: k used by the top-k prediction and top-k accuracy ops.

    Returns:
        dict holding the placeholders ('images', 'labels', 'keep_prob') and the
        tensors/ops ('global_step', 'train_op', 'loss', 'accuracy', 'accuracy_top_k',
        'merged_summary_op', 'predicted_distribution', 'predicted_index_top_k',
        'predicted_val_top_k'); 'top_k' echoes the argument.
    """
    size = FLAGS.image_size
    # with tf.device('/cpu:0'):
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    # Spatial size follows FLAGS.image_size (what input_pipeline resizes to)
    # instead of a hard-coded 64.
    images = tf.placeholder(dtype=tf.float32, shape=[None, size, size, 1], name='image_batch')
    labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

    conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
    max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
    conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
    max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
    conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
    max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')
    conv_4 = slim.conv2d(max_pool_3, 512, [3, 3], [2, 2], scope="conv4", padding="SAME")
    max_pool_4 = slim.max_pool2d(conv_4, [2, 2], [2, 2], padding="SAME")

    flatten = slim.flatten(max_pool_4)

    fc1 = slim.fully_connected(slim.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
    logits = slim.fully_connected(slim.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

    # Left as a float32 variable named "step" so previously written checkpoints
    # still match this graph.
    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)
    probabilities = tf.nn.softmax(logits)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()
    predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
    accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))

    return {'images': images,
            'labels': labels,
            'keep_prob': keep_prob,
            'top_k': top_k,
            'global_step': global_step,
            'train_op': train_op,
            'loss': loss,
            'accuracy': accuracy,
            'accuracy_top_k': accuracy_in_top_k,
            'merged_summary_op': merged_summary_op,
            'predicted_distribution': probabilities,
            'predicted_index_top_k': predicted_index_top_k,
            'predicted_val_top_k': predicted_val_top_k}


def _save_checkpoint(sess, saver, global_step):
    """Save the session's variables as FLAGS.checkpoint_dir/my-model-<global step>."""
    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=global_step)


def train():
    """Train on FLAGS.train_data_dir, evaluating one test batch every FLAGS.eval_steps.

    Resumes from the newest checkpoint when FLAGS.restore is set. Writes summaries
    under FLAGS.log_dir and checkpoints under FLAGS.checkpoint_dir. Stops (after a
    final save) once the global step exceeds FLAGS.max_steps.
    """
    print('Begin training')
    # Saver.save refuses to write into a missing directory.
    if not os.path.isdir(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    train_feeder = DataIterator(FLAGS.train_data_dir)
    test_feeder = DataIterator(FLAGS.test_data_dir)
    with tf.Session() as sess:
        train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
        graph = build_graph(top_k=1)
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()

        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))

        logger.info(':::Training Start:::')
        try:
            while not coord.should_stop():
                start_time = time.time()
                train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                feed_dict = {graph['images']: train_images_batch,
                             graph['labels']: train_labels_batch,
                             graph['keep_prob']: 0.8}
                _, loss_val, train_summary, step = sess.run(
                    [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                    feed_dict=feed_dict)
                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    # Periodic saves only fire when step % save_steps == 1, so keep the
                    # final weights explicitly before leaving the loop.
                    logger.info('Save the ckpt of {0}'.format(step))
                    _save_checkpoint(sess, saver, graph['global_step'])
                    break
                if step % FLAGS.eval_steps == 1:
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    accuracy_test, test_summary = sess.run(
                        [graph['accuracy'], graph['merged_summary_op']],
                        feed_dict=feed_dict)
                    test_writer.add_summary(test_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'
                                .format(step, accuracy_test))
                    logger.info('===============Eval a batch=======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    _save_checkpoint(sess, saver, graph['global_step'])
        except tf.errors.OutOfRangeError:
            logger.info('==================Train Finished================')
            _save_checkpoint(sess, saver, graph['global_step'])
        finally:
            coord.request_stop()
            # Flush pending summaries and release the event files.
            train_writer.close()
            test_writer.close()
        coord.join(threads)


def validation():
    """Evaluate the latest checkpoint with one pass over FLAGS.test_data_dir.

    Returns:
        dict with 'prob' (top-3 probabilities per image), 'indices' (top-3 class ids
        per image) and 'groundtruth' (true labels), all aligned by image.
    """
    print('validation')
    test_feeder = DataIterator(FLAGS.test_data_dir)

    final_predict_val = []
    final_predict_index = []
    groundtruth = []

    with tf.Session() as sess:
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
        graph = build_graph(top_k=3)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # initialize test_feeder's inside state

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))
        else:
            logger.warning('No checkpoint found in {0}; evaluating untrained weights'.format(
                FLAGS.checkpoint_dir))

        print(':::Start validation:::')
        i = 0
        correct_top_1, correct_top_k, seen = 0.0, 0.0, 0
        try:
            while not coord.should_stop():
                i += 1
                start_time = time.time()
                test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                feed_dict = {graph['images']: test_images_batch,
                             graph['labels']: test_labels_batch,
                             graph['keep_prob']: 1.0}
                batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],
                                                                       graph['predicted_val_top_k'],
                                                                       graph['predicted_index_top_k'],
                                                                       graph['accuracy'],
                                                                       graph['accuracy_top_k']], feed_dict=feed_dict)
                final_predict_val += probs.tolist()
                final_predict_index += indices.tolist()
                groundtruth += batch_labels.tolist()
                # Turn the per-batch mean accuracies back into correct-prediction counts.
                batch_count = len(batch_labels)
                correct_top_1 += acc_1 * batch_count
                correct_top_k += acc_k * batch_count
                seen += batch_count
                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                            .format(i, end_time - start_time, acc_1, acc_k))
        except tf.errors.OutOfRangeError:
            logger.info('==================Validation Finished================')
            # shuffle_batch discards the final partial batch, so normalise by the images
            # actually scored rather than test_feeder.size (which biased accuracy low).
            acc_top_1 = correct_top_1 / seen if seen else 0.0
            acc_top_k = correct_top_k / seen if seen else 0.0
            logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
            logger.info('scored {0} of {1} images'.format(seen, test_feeder.size))
        finally:
            coord.request_stop()
        coord.join(threads)
    return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}


def inference(image):
    """Return the top-3 ``(probabilities, class ids)`` for a single image file.

    The image is converted to grayscale, resized to FLAGS.image_size x FLAGS.image_size
    and scaled to [0, 1], then fed to the graph restored from FLAGS.checkpoint_dir.

    Raises:
        IOError: if FLAGS.checkpoint_dir contains no checkpoint.
    """
    print('inference')
    # Close the file handle promptly. LANCZOS is the filter the removed
    # Image.ANTIALIAS alias pointed to (ANTIALIAS is gone in Pillow 10).
    with Image.open(image) as img:
        gray = img.convert('L')
    gray = gray.resize((FLAGS.image_size, FLAGS.image_size), Image.LANCZOS)
    pixels = np.asarray(gray) / 255.0
    pixels = pixels.reshape([-1, FLAGS.image_size, FLAGS.image_size, 1])
    with tf.Session() as sess:
        logger.info('========start inference============')
        graph = build_graph(top_k=3)
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if not ckpt:
            # Without restored weights the variables are uninitialised and sess.run
            # would fail with an unhelpful FailedPreconditionError.
            raise IOError('No checkpoint found in {0}'.format(FLAGS.checkpoint_dir))
        saver.restore(sess, ckpt)
        predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                              feed_dict={graph['images']: pixels, graph['keep_prob']: 1.0})
    return predict_val, predict_index


def main(_):
    """Dispatch on FLAGS.mode: 'train', 'validation' or 'inference'."""
    print(FLAGS.mode)
    if FLAGS.mode == "train":
        train()
    elif FLAGS.mode == 'validation':
        dct = validation()
        result_file = 'result.dict'
        logger.info('Write result into {0}'.format(result_file))
        with open(result_file, 'wb') as f:
            pickle.dump(dct, f)
        logger.info('Write file ends')
    elif FLAGS.mode == 'inference':
        image_path = './data/00098/102544.png'
        # The true class id is the image's parent directory name (00098 -> 98);
        # the previously hard-coded 190 did not match this image.
        true_label = int(os.path.basename(os.path.dirname(image_path)))
        final_predict_val, final_predict_index = inference(image_path)
        logger.info('the result info label {0} predict index {1} predict_val {2}'.format(true_label,
                                                                                         final_predict_index,
                                                                                         final_predict_val))
    else:
        raise ValueError('Unknown mode {0!r}; expected "train", "validation" or "inference"'.format(FLAGS.mode))


if __name__ == "__main__":
    tf.app.run()

 

转载地址:http://ulsja.baihongyu.com/

你可能感兴趣的文章
java路径Java开发中获得非Web项目的当前项目路径
查看>>
【工具使用系列】关于 MATLAB 遗传算法与直接搜索工具箱,你需要知道的事
查看>>
Kali-linux Arpspoof工具
查看>>
PDF文档页面如何重新排版?
查看>>
基于http协议使用protobuf进行前后端交互
查看>>
bash腳本編程之三 条件判断及算数运算
查看>>
php cookie
查看>>
linux下redis安装
查看>>
弃 Java 而使用 Kotlin 的你后悔了吗?| kotlin将会是最好的开发语言
查看>>
JavaScript 数据类型
查看>>
量子通信和大数据最有市场突破前景
查看>>
StringBuilder用法小结
查看>>
对‘初学者应该选择哪种编程语言’的回答——计算机达人成长之路(38)
查看>>
如何申请开通微信多客服功能
查看>>
Sr_C++_Engineer_(LBS_Engine@Global Map Dept.)
查看>>
非监督学习算法:异常检测
查看>>
App开发中甲乙方冲突会闹出啥后果?H5 APP 开发可以改变现状吗
查看>>
jquery的checkbox,radio,select等方法总结
查看>>
Linux coredump
查看>>
Ubuntu 10.04安装水晶(Mercury)无线网卡驱动
查看>>