개복치
30
2019-04-15 22:39:28 작성 2019-04-15 22:48:37 수정됨
4
210

Tensorflow로 linear regression 모델 만드는 중 오류.. ㅠㅠ


tensorflow로 linear regression 모델 만드는 연습하는 중인데 무슨 문제인지 자꾸만 nan 오류가 뜹니다...

한 이틀을 붙잡고 고쳐봤는데도 해결이 안되요ㅠㅠ 

도움 좀 부탁드려요ㅠ 


우선 모델에 대한 코드입니다. 

Train data로 x_data 값과 y_data 값을 인풋으로 넣으면 그걸로 linear regression을 학습하고, x_prediction 값을 넣을 때 y값을 예측하는 모델입니다. 

state, state_range, plot 요런건 다 옵션이에요.

state는 진행상황을 출력할 것인가에 대한 거고

plot은 x_data에 대한 (예측) linear 함수 그래프로 그리기 옵션입니다.

def linear_regression(X1_data, Y1_data, X1_predict, state=False, state_range=20, plot=False):
    
    X_1 = tf.placeholder(tf.float32, shape=None)
    Y_1 = tf.placeholder(tf.float32, shape=None)
    
    W_1 = tf.Variable(tf.random_normal([1]), name='weight_1')
  
    b_1 = tf.Variable(tf.random_normal([1]), name='bias_1')
    
    
    hypothesis_1 = X_1 * W_1 + b_1

    
    cost_1 = tf.reduce_mean(tf.square(hypothesis_1 - Y_1))
   
    
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    
    train_1 = optimizer.minimize(cost_1)

    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    
    for step in range(2001):
        cost_val_1,W_val_1,b_val_1,_1,hypo_1 = (sess.run([cost_1,W_1,b_1,train_1,hypothesis_1], feed_dict={X_1:X1_data, Y_1:Y1_data}))
        
        if state == True:
            if step % state_range == 0: 
                (print("step: ", step,"\ncost_val_1: ",cost_val_1, "/ W_val_1: " ,W_val_1,"/ b_val_1: ", b_val_1, "\nhypo_1", hypo_1))
        
    
    predict_Y1 = sess.run(hypothesis_1, feed_dict={X_1:[X1_predict]})

    print("predict_Y1: ", predict_Y1)

    
    if plot == True:
        import matplotlib.pyplot as plt
        hypo_list = []
        for x1 in X1_data:
            hypo = sess.run(hypothesis_1, feed_dict={X_1:[x1]})
            hypo_list.append(hypo)
            
        print("hypo_list: ", hypo_list)
        fig = plt.figure(figsize=(5,3),dpi=100)
        ax = fig.add_axes([0,0,1,1])
        ax.plot(X1_data,Y1_data,label='train data')
        ax.plot(X1_data,hypo_list,label='hypothesis data')
        plt.show()
        


다음은 여기에다가 넣은 데이터입니다.

pandas로 묶어서 보여드리기는 했지만 X1_data와 Y1_data는 각각 리스트 형태입니다.

이거를 모델 함수에다가 넣어서 돌렸습니다.

예측도 안되고 모두 nan으로 나와버립니다. 

당연히 그래프도 제대로 안나옵니다 ㅠㅠ 


데이터가 문제인가 싶어서 input을 다르게 줘봤습니다.


그래프만 봐도 예측이 매우 잘 되는 것을 볼 수 있습니다.

이게 학습률 0.01로 할 때입니다. 


저 학습률에서 위의 df_X1, df_Y1을 X,Y로 넣으면 꼭 nan 으로 출력되는데,

학습률을 바꿔봤더니 0.0000001에서부터 nan이 아닌 실수가 나옵니다.

그러나 그것도 문제인 것이.. 


이렇게 하면 nan에서는 벗어나지만, 학습률이 너무 작아서 거의 하나의 값에 수렴해버립니다.

학습할 for문을 2001에서 50001로 , 100001로도 바꿔봤지만 여전히 문제는 똑같습니다.

이건 0.01로 예측이 잘 되던 데이터에 0.0000001로 돌렸을 때 입니다.


결국 학습률을 0.01로 하면 모델이 잘 돌아가는 대신 주어진 데이터에는 nan값이 나오지만,

학습률을 0.0000001로 하면 모델이 잘 안돌아가고(학습이 더딤) 대신 주어진 데이터는 적어도 nan에서 빠져나오긴 합니다.. 


이러한 결과만 보면 모델에 문제가 있다기 보다는 데이터가 모델이랑 맞지 않아서 생긴 문제 같은데, 

도저히 왜 때문인지 모르겠습니다. 

training data인 x,y 에 대한 그래프(파란선)만 봐도 어느 정도 선형 관계가 보이는데, 

완벽하지는 못하더라도 대략 양의 방향으로 올라가는 1차 함수가 나와야 정상이라고 생각하거든요.. 

그런데 저렇게 일직선으로 나오는건 말도 안되는거 같고

nan값은 더더욱 이해가 안됩니다...


아시는 분 도움 좀 부탁드려요 ㅠㅠ

구글링도 해보고 물어도 봤지만 도저히 혼자서 해결이 안되서 올려요..

0
0
  • 답변 4

  • dohyeong
    335
    2019-04-15 23:42:02

    X값을 1985 - 2015가 아니라 0 - 30로 바꾸고 learning rate를 적절히 조절해 보세요.

    1
  • adward13
    3
    2019-04-16 10:40:39

    일단 값이 많이 벗어나는 2015년도 데이터를 삭제해 보시고요 그래디언트 디센트 보다는 아담 옵티마이져를 한번써보세요 learing rate는 0.01정도면 적당합니다

    그리고 반복횟수를 100회서 부터 천천히 100단위로 증가시켜서 해보세요 아니면 1서부터 어떻게 변하는지 살펴보셔도 됩니다.

    1
  • 개복치
    30
    2019-04-16 11:13:22 작성 2019-04-16 11:51:30 수정됨

    @dohyeong // 오.. 0~30으로 바꾸고 learning rate를 0.001로 바꾸니까 딱 원하던 모양이 나오네요! 0.003일 때 다른 데이터에도 가장 잘 맞는거 같아요.


    그런데 0~30이나 1985~2015나 같은 증가폭인데 왜 값이 다른지 여전히 궁금하긴 하네요.

    여하튼 답변 감사합니다! 

    0
  • 개복치
    30
    2019-04-16 11:33:26

    @adward13 // 답변 감사합니다!

    2015값이 원래 32.428333인데 모델이 주로 마지막 값에 수렴해서 학습되는 것 같아 실험삼아 바꿔봤어요. 그래도 문제는 여전하더라구요 ㅠㅠ

    말씀주신대로 다시 원래값으로 돌려놓고 윗댓글대로 x데이터를 0~30으로 바꿨더니 어느 정도 모양이 나오네요. 

    알려주신 AdamOptimizer로 학습률 0.003 (반복횟수는 그대로)으로 돌려봤는데요. 무슨 이유인지 두번째 데이터에는 정 반대의 결과가 나왔어요.


    이건 학습률 0.1로 했을 때의 결과입니다.

    아담의 경우 학습률 0.1로 할 때 조금 더 잘 돌아가는 것 같네요. 

    아담은 작동원리를 몰라서 조금 더 공부해봐야 왜 인지 이해될 것 같습니다.

    감사해요! 

    0
  • 로그인을 하시면 답변을 등록할 수 있습니다.