# Simple linear regression computed three ways on simulated data:
#   1. lm()                                  (reference fit)
#   2. closed-form sums of squares           (a, b, R^2, F, t by hand)
#   3. batch gradient descent on z-scored x  (then unscaled back to x units)
#
# This is a plain, runnable R script: console prompts and printed output have
# been removed. Values quoted in "#>" comments were recorded from the original
# console session (set.seed(101)) and are kept for reference only.
#
# Run in a fresh R session. The script deliberately does not call
# rm(list = ls()), which would wipe the caller's workspace.

library(ggplot2)
library(ggpmisc)
library(tidyverse)
library(data.table)

# Sum of squared deviations of x from its own mean.
ss <- function(x) {
  sum((x - mean(x))^2)
}

# Data preparation ----
set.seed(101)
nx <- 50          # sample size of x
mx <- 4.5         # mean of x
sdx <- mx * 0.56  # sd of x
x <- rnorm(nx, mx, sdx)
slp <- 4          # slope used to generate y from x
# y = slp * x + Gaussian noise with sd slp * 3 * sdx
y <- slp * x + rnorm(nx, 0, slp * 3 * sdx)

data <- data.frame(x, y)
head(data)

# Reference fit with lm() ----
mo <- lm(y ~ x, data = data)
summary(mo)
#> (Intercept) -2.708 (SE 8.313), x 5.005 (SE 1.736, t 2.884, p 0.00587)
#> Multiple R-squared 0.1477, F 8.316 on 1 and 48 DF

ggplot(data = data, aes(x = x, y = y)) +
  geom_point() +
  stat_poly_line() +
  stat_poly_eq(use_label(c("eq", "R2"))) +
  theme_classic()

# Covariance and correlation by hand ----
sp.yx <- sum((x - mean(x)) * (y - mean(y)))  # sum of cross-products
df.yx <- length(y) - 1
sp.yx / df.yx                # covariance by hand      #> 27.61592
cov(x, y)                    # check with cov()         #> 27.61592
cov(x, y) / (sd(x) * sd(y))  # correlation by hand      #> 0.3842686
cor(x, y)                    # check with cor()         #> 0.3842686

# Regression coefficients by hand ----
b <- sp.yx / ss(x)           # equivalently cov(x, y) / var(x)
a <- mean(y) - b * mean(x)
a                            #> -2.708294
b                            #> 5.004838

# Same values from lm(). coef() is used instead of summary(mo)$coefficient,
# which only works through partial matching of `$`.
coef(mo)[["(Intercept)"]]
coef(mo)[["x"]]

# Partition of the total sum of squares ----
fit.yx <- a + b * x          # predicted y for each observed x
res <- y - fit.yx            # residuals: observed minus predicted
reg <- fit.yx - mean(y)      # regression part: predicted minus mean of y
ss.res <- sum(res^2)
ss.reg <- sum(reg^2)
ss.tot <- ss(y)
ss.res + ss.reg              #> 45864.4
ss.tot                       #> 45864.4

plot(x, y)
abline(a, b, col = "red", lwd = 2)
plot(x, fit.yx)
plot(x, res)

df.y <- length(y) - 1
df.reg <- 2 - 1              # two estimated parameters (a and b) minus one
df.res <- df.y - df.reg
df.res                       #> 48

r.sq <- ss.reg / ss.tot
r.sq                         #> 0.1476624
summary(mo)$r.squared        #> 0.1476624

ms.reg <- ss.reg / df.reg
ms.res <- ss.res / df.res

# F test for the regression.
f.cal <- ms.reg / ms.res
f.cal                        #> 8.315713
p.f <- pf(f.cal, df.reg, df.res, lower.tail = FALSE)
p.f                          #> 0.005867079

# t test for the slope. sqrt(f.cal) printed 2.883698 in the original session,
# the same magnitude as t.cal below.
sqrt(f.cal)
se.b <- sqrt(ms.res / ss(x))
se.b                         #> 1.735562
t.cal <- (b - 0) / se.b
t.cal                        #> 2.883698
# Two-sided p-value; abs() keeps it valid when the slope is negative.
p.t <- 2 * pt(abs(t.cal), df = df.res, lower.tail = FALSE)
p.t                          #> 0.005867079
summary(mo)

# Gradient descent ----
# Random starting values for intercept (a) and slope (b).
a <- rnorm(1)
b <- rnorm(1)
a.start <- a
b.start <- b
a.start                      #> 0.2680658
b.start                      #> -0.5922083

# NOTE(review): `predict` and `residuals` shadow the stats generics of the
# same name in the global environment; the names are kept so that code
# continuing this session keeps working.

# Predicted values of the line a + b * x.
predict <- function(x, a, b) {
  a + b * x
}

# Residuals: observed y minus fitted values.
residuals <- function(fit, y) {
  y - fit
}

# Gradient of the mean squared residual with respect to b and a, given the
# predictor values `x` and the current residuals `res`.
gradient <- function(x, res) {
  db <- -2 * mean(x * res)
  da <- -2 * mean(res)
  list(b = db, a = da)
}

# Mean squared residual (the loss being minimised).
msrloss <- function(fit, y) {
  res <- residuals(fit, y)
  mean(res^2)
}

learning.rate <- 1e-1        # 0.1
nlen <- 75                   # number of epochs

# Per-epoch history, preallocated rather than grown with append().
as <- numeric(nlen)          # intercept after each update
bs <- numeric(nlen)          # slope after each update
msrs <- numeric(nlen)        # mean squared residual before each update
mres <- numeric(nlen)        # mean residual before each update

# Train on the z-scored predictor.
zx <- (x - mean(x)) / sd(x)

for (epoch in seq_len(nlen)) {
  fit.val <- predict(zx, a, b)
  residual <- residuals(fit.val, y)
  mres[epoch] <- mean(residual)
  msrs[epoch] <- msrloss(fit.val, y)

  grad <- gradient(zx, residual)
  b <- b - grad$b * learning.rate
  a <- a - grad$a * learning.rate

  as[epoch] <- a
  bs[epoch] <- b
}

msrs                         #> starts at 1254.6253, levels off at 781.8391
mres                         #> starts at 17.98187, shrinks to about 1.2e-06
as
bs

# Coefficients on the z-score scale.
a                            #> 18.24993
b                            #> 11.75641

# Unscale coefficients to make them comprehensible.
# see http://commres.net/wiki/gradient_descent#why_normalize_scale_or_make_z-score_xi
# and
# http://commres.net/wiki/gradient_descent#how_to_unnormalize_unscale_a_and_b
#
# With zx = (x - mean(x)) / sd(x):
#   a + b * zx = (a - b * mean(x) / sd(x)) + (b / sd(x)) * x
# `a` must be converted first because it needs the still-scaled `b`.
mean.x <- mean(x)
sd.x <- sd(x)
a <- a - (mean.x / sd.x) * b
b <- b / sd.x
a                            #> -2.708293
b                            #> 5.004837

# The same conversion applied to the whole history of estimates.
as <- as - (mean.x / sd.x) * bs
bs <- bs / sd.x
as
bs

parameters <- data.frame(as, bs, mres, msrs)

cat(paste0("Intercept: ", a, "\n", "Slope: ", b, "\n"))
summary(mo)$coefficients

# Loss and mean residual per epoch.
msrs.log <- data.table(epoch = seq_len(nlen), msrs = msrs)
ggplot(msrs.log, aes(epoch, msrs)) +
  geom_line(color = "blue") +
  theme_classic()

mres.log <- data.table(epoch = seq_len(nlen), mres = mres)
ggplot(mres.log, aes(epoch, mres)) +
  geom_line(color = "red") +
  theme_classic()

ch <- data.frame(mres, msrs)
ch
max(y)                       #> 83.02991

# Every intermediate line (green), the first (blue) and the last (red),
# drawn over the data together with the least-squares line.
ggplot(data, aes(x = x, y = y)) +
  geom_point(size = 2) +
  geom_abline(
    aes(intercept = as, slope = bs),
    data = parameters, linewidth = 0.5,
    color = "green"
  ) +
  stat_poly_line() +
  stat_poly_eq(use_label(c("eq", "R2"))) +
  theme_classic() +
  geom_abline(
    aes(intercept = as, slope = bs),
    data = parameters %>% slice_head(),
    linewidth = 1, color = "blue"
  ) +
  geom_abline(
    aes(intercept = as, slope = bs),
    data = parameters %>% slice_tail(),
    linewidth = 1, color = "red"
  ) +
  labs(title = "Gradient descent. blue: start, red: end, green: gradients")

# Final comparison: starting values, gradient-descent estimates, lm() estimates.
summary(mo)
a.start
b.start
a
b
coef(mo)[["(Intercept)"]]
coef(mo)[["x"]]