RNN Learning Notes: Predicting sin with a TensorFlow LSTM

RNNs are made for variable-length sequence data, and simple harmonic wave prediction is a nice example. Clearly, to a learner, sin(x), cos(x), a*sin(x) + b*cos(x), or even x*sin(x) make no difference at all in difficulty.

So let's predict it with an RNN!

(Ugh... thanks to an accidental bug, I discovered that a plain linear-regression-style feedforward network solves this problem very precisely!!! And the LSTM is no match for it!!!)

Consider the relationship between sin(a(k-1)+b), sin(ak+b), and sin(a(k+1)+b):

sin(a(k-1)+b) = sin(ak+b-a) = sin(ak+b)cos(a) - cos(ak+b)sin(a)

sin(a(k+1)+b) = sin(ak+b+a) = sin(ak+b)cos(a) + cos(ak+b)sin(a)

Adding the two identities, the cos(ak+b) terms cancel:

sin(a(k+1)+b) = -sin(a(k-1)+b) + 2sin(ak+b)cos(a)

Since I take a to be an integer, cos(a) is very easy to estimate, so every term of the sin sequence can in fact be expressed as a linear combination of the two terms before it...

Even when a is not an integer, the same trigonometric identity applies, so the next term is still a linear expression in the previous two (with coefficient 2cos(a)); only cos(a) has to be estimated from data.
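To sanity-check the recurrence (this snippet is my own addition, not part of the original script): fit the next sample as a linear function of the previous two with least squares; the fitted coefficients should come out as -1 and 2cos(a), even for a non-integer a.

import numpy as np

a, b = 0.7, 1.3                          # an arbitrary non-integer frequency and phase
k = np.arange(100)
s = np.sin(a * k + b)

# regress s[k+1] on (s[k-1], s[k]) by least squares
A = np.stack([s[:-2], s[1:-1]], axis=1)
coef, _, _, _ = np.linalg.lstsq(A, s[2:], rcond=None)
print(coef)                              # ~ [-1.0, 1.5297]
print(-1.0, 2 * np.cos(a))               # the coefficients the identity predicts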

But this raises a question.

If the problem is this simple, why does the LSTM perform so badly???

The possible reasons:

  1. The LSTM has too many parameters, so training runs into vanishing gradients.
  2. With an LSTM, the second-to-last element only reaches the last one through several nonlinear layers, so a linear combination that should be trivial becomes very hard to imitate!
  3. I wrote a bug...

It was the last one: like an idiot, I had stuck a sigmoid on the final layer, not realizing that sigmoid only returns positive values and therefore cannot cover the range of sin(x).
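A minimal illustration of the range problem (this snippet is my addition, not from the original script): sigmoid maps everything into (0, 1), so the negative half of sin's range [-1, 1] is simply unreachable, while tanh maps into (-1, 1) and works.

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

z = np.linspace(-5.0, 5.0, 101)
print(sigmoid(z).min(), sigmoid(z).max())   # stays inside (0, 1): can never hit negative targets
print(np.tanh(z).min(), np.tanh(z).max())   # covers (-1, 1): matches sin's range

Below is the full script, with the output layer switched to tanh: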

import tensorflow as tf
import numpy as np
import random
from matplotlib import pyplot

max_length = 20            # maximum sequence length (shorter ones are zero-padded)
total_branches = 10010     # total number of generated samples
batch_size = 31
hidden_size = 16           # LSTM hidden state size
start_position = 10        # minimum sequence length
test_batch_index = 0       # cursor into the sample pool for batching
num_input_classes = 1      # one scalar feature per time step
difficulty = 10            # upper bound of the random interval endpoints
total_layers = 1           # stacked LSTM layers (unused while MultiRNNCell stays commented out)

# sample pool, filled in by generate_batches()
test_x = None
test_y = None
test_seqlen = None

def gen_sine_wave(length):
    # sample sin over a random integer interval [L, R]; the first `length`
    # points form the input sequence and the final point is the label
    L = random.randint(1, difficulty)
    R = random.randint(1, difficulty)
    if L > R:
        L, R = R, L
    x = np.linspace(L, R, length + 1)
    y = np.sin(x)
    return y[:-1], y[-1]

def generate_batches():
    # build the whole pool of zero-padded sequences, labels and true lengths
    global test_x, test_y, test_seqlen
    test_x = []
    test_y = []
    test_seqlen = []
    for i in range(0, total_branches):
        length = random.randint(start_position, max_length)
        x, y = gen_sine_wave(length)
        x = np.append(x, np.zeros(max_length - length))   # pad up to max_length
        x = np.reshape(x, [max_length, num_input_classes])
        test_seqlen.append(length)
        test_x.append(x)
        test_y.append(y)

def next_batch(total_branches=total_branches):
    # return the next slice of the pool, wrapping around once exhausted
    global test_batch_index, test_x, test_y, test_seqlen
    if test_batch_index == total_branches:
        test_batch_index = 0
    end = min(test_batch_index + batch_size, total_branches)
    xs = test_x[test_batch_index:end]
    ys = test_y[test_batch_index:end]
    seqlens = test_seqlen[test_batch_index:end]
    test_batch_index = end
    return xs, ys, seqlens


if __name__ == '__main__':
    generate_batches()
    # placeholders: padded sequences, their true lengths, and scalar labels
    input_data = tf.placeholder(tf.float32, [None, max_length, num_input_classes])
    input_seqlen = tf.placeholder(tf.int32, [None])
    input_labels = tf.placeholder(tf.float32, [None])
    keep_prob = tf.placeholder(tf.float32)   # only takes effect if the DropoutWrapper below is enabled
    current_batch_size = tf.shape(input_data)[0]

    with tf.name_scope('lstm_layer'):
        method = 'lstm'
        data = input_data
        print(data.shape)  # (?, max_length, num_input_classes)
        if method == 'lstm':
            def create_lstm():
                lstm_cell = tf.contrib.rnn.BasicLSTMCell(hidden_size, state_is_tuple=True)
                #lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=keep_prob, state_keep_prob=keep_prob)
                return lstm_cell
            lstm_cell = create_lstm()
            #lstm_cell = tf.nn.rnn_cell.MultiRNNCell([create_lstm() for i in range(total_layers)], state_is_tuple=True)
            # static_rnn expects a Python list of per-timestep tensors
            data = tf.unstack(data, axis=1)
            outputs, state = tf.nn.static_rnn(lstm_cell, data, dtype=tf.float32, sequence_length=input_seqlen)
            outputs = tf.stack(outputs, axis=1)   # back to (?, max_length, hidden_size)
        else:
            gru_cell = tf.contrib.rnn.GRUCell(hidden_size)
            outputs, state = tf.nn.dynamic_rnn(gru_cell, data, dtype=tf.float32, sequence_length=input_seqlen)

        print(outputs.shape)
        # pick, for each sequence, the output at its last valid time step
        output = tf.gather_nd(outputs, tf.stack([tf.range(current_batch_size), input_seqlen - 1], axis=1))

    with tf.name_scope('final_layer'):
        # stddev must be passed by keyword; the bare .1 used to land in the mean argument
        final_w = tf.Variable(tf.truncated_normal([hidden_size, 1], stddev=.1))
        tf.summary.histogram('weight', final_w)
        final_b = tf.Variable(tf.constant(.1, shape=[1]))
        tf.summary.histogram('bias', final_b)

    with tf.name_scope('output_layer'):
        # tanh, not sigmoid: the targets are sin values in [-1, 1]
        result = tf.tanh(tf.matmul(output, final_w) + final_b)
        print(result.shape)
        result = tf.reshape(result, [-1])
        loss = tf.reduce_mean(tf.losses.mean_squared_error(labels=input_labels, predictions=result))
        tf.summary.scalar('loss', loss)

    optimizer = tf.train.AdamOptimizer().minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter("logs/",sess.graph)

        steps = 0
        while True:
            batch_x , batch_y, batch_seqlen = next_batch()
            sess.run(optimizer, feed_dict={input_data:batch_x,
                input_labels:batch_y,
                input_seqlen:batch_seqlen,
                keep_prob:.95})
            if steps%50 == 0:
                # fetch loss, predictions and the summaries in a single run instead of three
                current_loss, predictions, merged_result = sess.run(
                    [loss, result, merged],
                    feed_dict={input_data:batch_x,
                        input_labels:batch_y,
                        input_seqlen:batch_seqlen,
                        keep_prob:1})
                print(steps, current_loss)
                print("PREDICT", predictions)
                print("LABEL", batch_y)
                writer.add_summary(merged_result, steps)
            steps += 1
            if steps%1000 == 0:
                # plot one sequence, the prediction (red +) and the true label (red ^)
                batch_x, batch_y, batch_seqlen = next_batch()
                current_result = sess.run(result, feed_dict={input_data:batch_x, input_seqlen:batch_seqlen, keep_prob:1})
                X = batch_x[0][:batch_seqlen[0]]
                Y = current_result[0]
                Z = batch_y[0]
                pyplot.figure()
                pyplot.plot(np.linspace(1, len(X), len(X)), X)
                pyplot.plot(len(X)+1, Y, 'r+')
                pyplot.plot(len(X)+1, Z, 'r^')
                pyplot.show()
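For reference, here is a minimal sketch of the feedforward baseline I mentioned at the top. The exact layer sizes are assumptions, not the network I actually ran; it reuses max_length and hidden_size from the script above and feeds the zero-padded sequence into one tanh hidden layer.

# hypothetical feedforward baseline (a sketch under assumptions, not the exact
# network referenced above); reuses max_length / hidden_size from the script
flat_input = tf.placeholder(tf.float32, [None, max_length])
flat_labels = tf.placeholder(tf.float32, [None])

# one tanh hidden layer lets the net account for the per-sample frequency; for
# a fixed frequency, a single linear layer on the last two samples suffices
w1 = tf.Variable(tf.truncated_normal([max_length, hidden_size], stddev=.1))
b1 = tf.Variable(tf.constant(.1, shape=[hidden_size]))
hidden = tf.tanh(tf.matmul(flat_input, w1) + b1)

w2 = tf.Variable(tf.truncated_normal([hidden_size, 1], stddev=.1))
b2 = tf.Variable(tf.constant(.1, shape=[1]))
prediction = tf.reshape(tf.tanh(tf.matmul(hidden, w2) + b2), [-1])

baseline_loss = tf.reduce_mean(tf.squared_difference(prediction, flat_labels))
baseline_step = tf.train.AdamOptimizer().minimize(baseline_loss)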
