MY MEMO
[MACHINE LEARNING] Linear Regression 의 cost 최소화의 TensorFlow 구현 본문
[MACHINE LEARNING] Linear Regression 의 cost 최소화의 TensorFlow 구현
l_j_yeon 2017. 4. 7. 15:24이번에는 지난번에 했던 cost function을 이용하여 gradient descent 그래프를 그려볼것이다.
import matplotlib.pyplot 를 하기 위해서는 먼저 matplotlib를 깔아야한다.
window 사용자는 matplotlib를 pip로 깔고 font_manager.py에 코드를 추가시킨다.
출처 : http://blog.naver.com/udgttl12/220954605729
그럼 코드를 살펴보자.
data값이 주어졌고 W 0.1*i로 설정하였다.
그러기 위해 처음 W를 placeholder function을 사용하였다.
cost function을 구하고 session에 넣어준다.
cost function의 결과값을 W_val와 cost_val에 넣어주고 그것을 plot를 이용하여 그래프를 출력시키면
위와같은 그래프가 나온다. J function의 그래프이다.
여기서 step을 이용하여 가장 작은 값을 찾으면
모든 데이터와 차이가 제일 작은 그래프를 그릴 수 있는 변수를 찾을 수 있다.
그래프를 그린 이후에 Gradient descent를 구하는 코드이다.
learning rate는 한 step의 크기이고
gradient 함수를 코드로 나타내면 위와 같다.
그럼 나온 값을 W에 계속 update해줘야하기 때문에
update변수를 사용하여 계산한 값을 W에 assign해준다.
cost가 점점 작아지고 w의 값이 1로 도달하는 것을 볼 수있다.
만약 w가 5부터 시작한다면 gradient descent 는 어떤식으로 돌아갈까?
w가 5이면 그래프의 오른쪽 위 방향부터 시작하여 minimum값을 찾는다.
위의 코드를 실행시키면 9번째만에 1.0이라는 완벽한 정답 값을 찾는다.
만약에 w가 왼쪽 위 갑부터 시작하면 값은 어떻게 변할까?
-3부터 w가 시작해도 마찬가지로 빠르게 정확한 답을 찾는다.
이번에는 우리가 계산한 gradient와 함수로 계산한 gradient가 과연 같은 값을 갖는지 비교해보자
gvs = optimizer.compute_gradients(cost, [w]) 이렇게 변경해주면 오류가 해결된다.
위를 실행하면
우리가 만든 함수의 gradient와 w 가 앞에 출력되고 이후 computer에서 실행되는 gradient와 w가 출력된다.
위를 보면 두개의 값이 정확히 일치한다는 것을 알 수 있다.
+)
tensorflow 홈페이지에가면 optimizers에 대한 설명이 코드와 함께 나와있다.
직관적으로 optimizers함수가 무엇을 할지 알수있지만
정확하게 공부를 하는 것도 필요하다.
+) 최적화(optimization) : 손실함수(loss function)을 최소화시카는 파라미터(parameter/weight, )들을 찾는 과정
+) compute_gradients() 는 손실에 대한 변수들의 기울기를 tf.gradient() 로 구합니다. 함수의 인자로 변수들을 넣어 주지 않아도, TF가 알아서 이 손실을 계산하는 데 필요한 모든 변수를 알아서 가져오고 그 변수들에 대해 기울기를 구합니다.
물론, var_list 라는 인자를 통해 특정 변수들에 대해서만 기울기를 계산할 수도 있습니다! 전체 그래프의 일부 변수만을 업데이트하고 싶을 때 유용하게 쓰일 수 있습니다.
apply_gradients() 는 위에서 구한 기울기 (혹은 외부에서 구해 온 기울기) 를 변수들에 업데이트합니다
출처 : http://deepestdocs.readthedocs.io/en/latest/001_tensorflow/0010/
'MACHINE LEARNING > Sung Kim - 실습' 카테고리의 다른 글
[MACHINE LEARNING] Softmax Function (0) | 2017.04.19 |
---|---|
[MACHINE LEARNING] Logistic classifier (0) | 2017.04.19 |
[MACHINE LEARNING] Multi Variable Linear Regression (0) | 2017.04.17 |
[MACHINE LEARNING] 번외 tensorflow matplotlib 설치에 관해 (2) | 2017.04.07 |
[MACHINE LEARNING] Linear Regression (0) | 2017.04.06 |