Gradient Descent

Gradient descent = finding the desired slope (the derivative) by whittling it down little by little.
prerequisite:
Why the mean is used when inferring the standard deviation: an experimental and mathematical understanding (표준편차 추론에서 평균을 사용하는 이유: 실험적_수학적_이해)
derivation of a and b in a simple regression

The documents above obtained the values of a and b directly by differentiation. Doing this analytically is not easy for a computer. Is there, then, a way to extract these values by iterative calculation? That is gradient descent.

The (second) document above explained that we look for the value that makes SS a minimum; the same argument can be made with MS instead.

\begin{eqnarray*} \text{MS} & = & \frac {\text{SS}}{n} \end{eqnarray*}
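
Here SS is the sum of the squared residuals around the regression line:

\begin{eqnarray*} \text{SS} & = & \sum{(Y_i - \hat{Y_i})^2} \; = \; \sum{(Y_i - (a + bX_i))^2} \end{eqnarray*}

Taking the derivative of the MSE with respect to the intercept a gives: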

\begin{eqnarray*} \text{for a (the intercept)} \\ \\ \dfrac{\partial}{\partial a} \text{MSE (Mean Square Error)} & = & \dfrac{\partial}{\partial a} \frac {\sum{(Y_i - (a + bX_i))^2}} {N} \\ & = & \sum \dfrac{\partial}{\partial a} \frac{(Y_i - (a + bX_i))^2} {N} \\ & = & \sum{\frac{2}{N} (Y_i - (a + bX_i))} \cdot (-1) \\ & & \because \;\; \dfrac{\partial}{\partial a} (Y_i - (a+bX_i)) = -1 \\ & = & - \frac{2}{N} \sum{(Y_i - (a + bX_i))} \\ \end{eqnarray*}
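
The same steps applied to the slope b bring down a factor of -X_i instead of -1:

\begin{eqnarray*} \text{for b (the slope)} \\ \\ \dfrac{\partial}{\partial b} \text{MSE} & = & \sum{\frac{2}{N} (Y_i - (a + bX_i))} \cdot (-X_i) \\ & = & - \frac{2}{N} \sum{X_i (Y_i - (a + bX_i))} \\ \end{eqnarray*}

Gradient descent then repeatedly nudges a and b against these two gradients by a small learning rate \eta until the MSE stops decreasing:

\begin{eqnarray*} a & \leftarrow & a - \eta \, \dfrac{\partial \, \text{MSE}}{\partial a} \\ b & \leftarrow & b - \eta \, \dfrac{\partial \, \text{MSE}}{\partial b} \\ \end{eqnarray*}

The R code below implements exactly this, with b0 standing for a, b1 for b, and learning_rate for \eta.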

library(tidyverse)

# A tiny example (the one used in the StatQuest explanation);
# these vectors are only illustrative and are cleared below.
x <- c(0.5, 2.3, 2.9)
y <- c(1.4, 1.9, 3.2)

rm(list=ls())     # clear the workspace (loaded packages stay attached)
# set.seed(191)   # uncomment for a reproducible simulation
n <- 500
x <- rnorm(n, 5, 1.2)           # predictor: normal with mean 5, sd 1.2
y <- 2.14 * x + rnorm(n, 0, 4)  # true slope 2.14, plus noise with sd 4

# data <- data.frame(x, y)
data <- tibble(x = x, y = y)
data

# OLS fit with lm() for later comparison with the gradient-descent estimates
mo <- lm(y~x)
summary(mo)

# set.seed(191)
# Initialize betas at random starting values
b1 = rnorm(1)
b0 = rnorm(1)

# Prediction function (note: this masks stats::predict in this session)
predict <- function(x, b0, b1){
  return (b0 + b1 * x)
}

# Residuals and the loss (MSE) functions:
residuals <- function(predictions, y) {
  return(y - predictions)
}

loss_mse <- function(predictions, y){
  residuals = y - predictions
  return(mean(residuals ^ 2))
}

predictions <- predict(x, b0, b1)
residuals <- residuals(predictions, y)  # this vector now masks the residuals() function
loss = loss_mse(predictions, y)

temp.sum <- data.frame(x, y, b0, b1, predictions, residuals)
temp.sum

print(paste0("Loss is: ", round(loss)))

# Gradient of the MSE with respect to b1 and b0 (the derivatives derived above)
gradient <- function(x, y, predictions){
  dinputs = y - predictions
  db1 = -2 * mean(x * dinputs)
  db0 = -2 * mean(dinputs)
  
  return(list("db1" = db1, "db0" = db0))
}
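
# (Optional sanity check; a minimal sketch that is not part of the original run.
#  It compares the analytic gradient above with a centered finite-difference
#  approximation of the MSE. `eps` is an arbitrary small step used for illustration.)
eps <- 1e-6
num_db1 <- (loss_mse(predict(x, b0, b1 + eps), y) -
            loss_mse(predict(x, b0, b1 - eps), y)) / (2 * eps)
num_db0 <- (loss_mse(predict(x, b0 + eps, b1), y) -
            loss_mse(predict(x, b0 - eps, b1), y)) / (2 * eps)
c(num_db1 = num_db1, num_db0 = num_db0)  # should closely match the gradients below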

gradients <- gradient(x, y, predictions)
print(gradients)

# Train the model with scaled (standardized) x: centering and scaling keeps the
# two gradients on a comparable scale, so a single learning rate works for both.
x_scaled <- (x - mean(x)) / sd(x)

learning_rate = 1e-1

# Record Loss for each epoch:
logs = list()
bs=list()
b0s = c()
b1s = c()
mse = c()

nlen <- 80
for (epoch in 1:nlen){
  # Predict all y values:
  predictions = predict(x_scaled, b0, b1)
  loss = loss_mse(predictions, y)
  mse = append(mse, loss)
  
  logs = append(logs, loss)
  
  if (epoch %% 10 == 0){
    print(paste0("Epoch: ",epoch, ", Loss: ", round(loss, 5)))
  }
  
  gradients <- gradient(x_scaled, y, predictions)
  db1 <- gradients$db1
  db0 <- gradients$db0
  
  b1 <- b1 - db1 * learning_rate
  b0 <- b0 - db0 * learning_rate
  b0s <- append(b0s, b0)
  b1s <- append(b1s, b1)
}
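# (A minimal sketch of an alternative, left commented out so it does not alter
#  the results above: instead of a fixed 80 epochs, stop once the loss change
#  drops below a tolerance. `tol` and `max_epochs` are illustrative values.)
# tol <- 1e-8
# max_epochs <- 1000
# prev_loss <- Inf
# for (epoch in 1:max_epochs){
#   predictions <- predict(x_scaled, b0, b1)
#   loss <- loss_mse(predictions, y)
#   if (abs(prev_loss - loss) < tol) break
#   prev_loss <- loss
#   g <- gradient(x_scaled, y, predictions)
#   b1 <- b1 - g$db1 * learning_rate
#   b0 <- b0 - g$db0 * learning_rate
# }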
# Unscale the coefficients so they refer to the original x rather than x_scaled:
#   y = b0 + b1 * (x - mean(x)) / sd(x)
#     = (b0 - b1 * mean(x) / sd(x)) + (b1 / sd(x)) * x
# (b0 must be adjusted first, while b1 is still on the scaled scale.)
b0 =  b0 - (mean(x) / sd(x)) * b1
b1 = b1 / sd(x)

b0s <- b0s - (mean(x) / sd(x)) * b1s
b1s <- b1s / sd(x)

parameters <- tibble(data.frame(b0s, b1s, mse))

cat(paste0("Inclination: ", b1, ", \n", "Intercept: ", b0, "\n"))
summary(lm(y~x))$coefficients

ggplot(data, aes(x = x, y = y)) + 
  geom_point(size = 2) + 
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = parameters, linewidth = 0.5, color = 'red') + 
  theme_classic() +
  geom_abline(aes(intercept = b0s, slope = b1s), 
              data = parameters %>% slice_head(), 
              linewidth = 0.5, color = 'blue') + 
  geom_abline(aes(intercept = b0s, slope = b1s), 
              data = parameters %>% slice_tail(), 
              linewidth = 1, color = 'green') +
  labs(title = 'Gradient descent: blue: start, green: end')
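
# (A minimal sketch: plot the recorded MSE against the epoch number to see how
#  quickly the loss flattens out; `epoch` is derived from the row order of `parameters`.)
parameters %>%
  mutate(epoch = row_number()) %>%
  ggplot(aes(x = epoch, y = mse)) +
  geom_line(color = 'red') +
  geom_point(size = 1) +
  theme_classic() +
  labs(title = 'MSE by epoch (gradient descent)')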
data
parameters
> rm(list=ls())
> # set.seed(191)
> n <- 500
> x <- rnorm(n, 5, 1.2)
> y <- 2.14 * x + rnorm(n, 0, 4)
> 
> # data <- data.frame(x, y)
> data <- tibble(x = x, y = y)
> data
# A tibble: 500 × 2
       x     y
   <dbl> <dbl>
 1  4.48 11.1 
 2  6.45 10.2 
 3  6.41 11.7 
 4  5.35 15.8 
 5  5.17  8.84
 6  3.64  1.36
 7  6.35 10.9 
 8  3.30 10.7 
 9  6.30  6.98
10  3.81  5.22
# ℹ 490 more rows
# ℹ Use `print(n = ...)` to see more rows
> 
> mo <- lm(y~x)
> summary(mo)

Call:
lm(formula = y ~ x)

Residuals:
     Min       1Q   Median       3Q      Max 
-10.2534  -2.6615   0.0087   2.7559   9.7626 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)   0.1281     0.7108    0.18    0.857    
x             2.1606     0.1388   15.57   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 3.822 on 498 degrees of freedom
Multiple R-squared:  0.3273,	Adjusted R-squared:  0.326 
F-statistic: 242.3 on 1 and 498 DF,  p-value: < 2.2e-16

> 
> # set.seed(191)
> # Initialize random betas
> b1 = rnorm(1)
> b0 = rnorm(1)
> 
> # Predict function:
> predict <- function(x, b0, b1){
+   return (b0 + b1 * x)
+ }
> 
> # And loss function is:
> residuals <- function(predictions, y) {
+   return(y - predictions)
+ }
> 
> loss_mse <- function(predictions, y){
+   residuals = y - predictions
+   return(mean(residuals ^ 2))
+ }
> 
> predictions <- predict(x, b0, b1)
> residuals <- residuals(predictions, y)
> loss = loss_mse(predictions, y)
> 
> temp.sum <- data.frame(x, y, b0, b1,predictions, residuals)
> temp.sum
           x          y        b0        b1 predictions   residuals
1   4.479742 11.1333005 0.5808843 0.9861742    4.998691  6.13460981
2   6.452609 10.1579559 0.5808843 0.9861742    6.944280  3.21367543
3   6.413290 11.6979692 0.5808843 0.9861742    6.905505  4.79246407
4   5.345206 15.7547560 0.5808843 0.9861742    5.852188  9.90256769
5   5.173454  8.8357907 0.5808843 0.9861742    5.682811  3.15297971
6   3.636604  1.3575950 0.5808843 0.9861742    4.167209 -2.80961437
7   6.348325 10.8649166 0.5808843 0.9861742    6.841439  4.02347773
8   3.299825 10.7060946 0.5808843 0.9861742    3.835087  6.87100778
9   6.300136  6.9805495 0.5808843 0.9861742    6.793915  0.18663406
10  3.812044  5.2219952 0.5808843 0.9861742    4.340224  0.88177101
 [ remaining rows omitted ]
> 
> print(paste0("Loss is: ", round(loss)))
[1] "Loss is: 46"
> 
> gradient <- function(x, y, predictions){
+   dinputs = y - predictions
+   db1 = -2 * mean(x * dinputs)
+   db0 = -2 * mean(dinputs)
+   
+   return(list("db1" = db1, "db0" = db0))
+ }
> 
> gradients <- gradient(x, y, predictions)
> print(gradients)
$db1
[1] -57.11316

$db0
[1] -10.77174

> 
> # Train the model with scaled features
> x_scaled <- (x - mean(x)) / sd(x)
> 
> learning_rate = 1e-1
> 
> # Record Loss for each epoch:
> logs = list()
> bs=list()
> b0s = c()
> b1s = c()
> mse = c()
> 
> nlen <- 80
> for (epoch in 1:nlen){
+   # Predict all y values:
+   predictions = predict(x_scaled, b0, b1)
+   loss = loss_mse(predictions, y)
+   mse = append(mse, loss)
+   
+   logs = append(logs, loss)
+   
+   if (epoch %% 10 == 0){
+     print(paste0("Epoch: ",epoch, ", Loss: ", round(loss, 5)))
+   }
+   
+   gradients <- gradient(x_scaled, y, predictions)
+   db1 <- gradients$db1
+   db0 <- gradients$db0
+   
+   b1 <- b1 - db1 * learning_rate
+   b0 <- b0 - db0 * learning_rate
+   b0s <- append(b0s, b0)
+   b1s <- append(b1s, b1)
+ }
[1] "Epoch: 10, Loss: 16.50445"
[1] "Epoch: 20, Loss: 14.56909"
[1] "Epoch: 30, Loss: 14.54677"
[1] "Epoch: 40, Loss: 14.54651"
[1] "Epoch: 50, Loss: 14.54651"
[1] "Epoch: 60, Loss: 14.54651"
[1] "Epoch: 70, Loss: 14.54651"
[1] "Epoch: 80, Loss: 14.54651"
> # I must unscale coefficients to make them comprehensible
> b0 =  b0 - (mean(x) / sd(x)) * b1
> b1 = b1 / sd(x)
> 
> b0s <- b0s - (mean(x) /sd(x)) * b1s
> b1s <- b1s / sd(x)
> 
> parameters <- tibble(data.frame(b0s, b1s, mse))
> 
> cat(paste0("Inclination: ", b1, ", \n", "Intercept: ", b0, "\n"))
Inclination: 2.16059976407543, 
Intercept: 0.128130381671001
> summary(lm(y~x))$coefficients
             Estimate Std. Error    t value     Pr(>|t|)
(Intercept) 0.1281304  0.7108462  0.1802506 8.570292e-01
x           2.1605998  0.1387908 15.5673144 8.229814e-45
> 
> ggplot(data, aes(x = x, y = y)) + 
+   geom_point(size = 2) + 
+   geom_abline(aes(intercept = b0s, slope = b1s),
+               data = parameters, linewidth = 0.5, color = 'red') + 
+   theme_classic() +
+   geom_abline(aes(intercept = b0s, slope = b1s), 
+               data = parameters %>% slice_head(), 
+               linewidth = 0.5, color = 'blue') + 
+   geom_abline(aes(intercept = b0s, slope = b1s), 
+               data = parameters %>% slice_tail(), 
+               linewidth = 1, color = 'green') +
+   labs(title = 'Gradient descent: blue: start, green: end')
> data
# A tibble: 500 × 2
       x     y
   <dbl> <dbl>
 1  4.48 11.1 
 2  6.45 10.2 
 3  6.41 11.7 
 4  5.35 15.8 
 5  5.17  8.84
 6  3.64  1.36
 7  6.35 10.9 
 8  3.30 10.7 
 9  6.30  6.98
10  3.81  5.22
# ℹ 490 more rows
# ℹ Use `print(n = ...)` to see more rows
> parameters
# A tibble: 80 × 3
      b0s   b1s   mse
    <dbl> <dbl> <dbl>
 1 -2.69   1.07 123. 
 2 -2.12   1.29  84.1
 3 -1.67   1.46  59.1
 4 -1.31   1.60  43.0
 5 -1.02   1.71  32.8
 6 -0.791  1.80  26.2
 7 -0.606  1.87  22.0
 8 -0.459  1.93  19.3
 9 -0.341  1.98  17.6
10 -0.247  2.01  16.5
# ℹ 70 more rows
# ℹ Use `print(n = ...)` to see more rows
> 
