gradient_descent:output01
This is an old revision of the document!
> # Packages: ggplot2 for plotting; ggpmisc supplies stat_poly_line() and
> # stat_poly_eq(), which are used in the plot below.
> library(ggplot2)
> library(ggpmisc)
>
> # NOTE(review): rm(list=ls()) wipes every object in the global environment;
> # starting from a fresh R session is the safer way to get a clean slate.
> rm(list=ls())
> # set.seed() is commented out, so every run draws new data and the
> # numbers printed in this transcript will not be reproduced exactly.
> # set.seed(191)
> # Simulate nx points with x ~ Normal(mean = mx, sd = sdx).
> nx <- 200
> mx <- 4.5
> sdx <- mx * 0.56
> x <- rnorm(nx, mx, sdx)
> # y = slp * x plus Normal noise with mean 0 and sd = slp * sdx * 3,
> # i.e. the true slope is slp and the true intercept is 0.
> slp <- 12
> y <- slp * x + rnorm(nx, 0, slp*sdx*3)
>
> data <- data.frame(x, y)
>
> # Ordinary least-squares fit of y on x; its coefficients are the reference
> # values that the intercept grid search is compared against.
> mo <- lm(y ~ x, data = data)
> summary(mo)
Call:
lm(formula = y ~ x, data = data)
Residuals:
Min 1Q Median 3Q Max
-259.314 -59.215 6.683 58.834 309.833
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 8.266 12.546 0.659 0.511
x 11.888 2.433 4.887 2.11e-06 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 88.57 on 198 degrees of freedom
Multiple R-squared: 0.1076, Adjusted R-squared: 0.1031
F-statistic: 23.88 on 1 and 198 DF, p-value: 2.111e-06
>
> # Scatter plot of the simulated data with the fitted regression line;
> # stat_poly_eq() prints the fitted equation and R^2 on the panel.
> ggplot(data = data, aes(x = x, y = y)) +
+ geom_point() +
+ stat_poly_line() +
+ stat_poly_eq(use_label(c("eq", "R2"))) +
+ theme_classic()
> # set.seed(191)
> # Initial coefficients. These are not random: b is fixed at the slope
> # estimated by lm(), and a (the intercept) starts at 0.
> # First hold b fixed and vary only a, to see how the loss responds
> # to the intercept on its own.
> # coefficients is a matrix indexed column-major, so element [2] is the
> # Estimate for x.
> b <- summary(mo)$coefficients[2]
> a <- 0
>
> # Keep copies of the starting values (not used later in this transcript).
> b.init <- b
> a.init <- a
>
> # Prediction for a straight line: a + b * x (vectorised over x).
> # NOTE(review): defining `predict` in the global environment masks
> # stats::predict(), so predict(mo) no longer reaches predict.lm();
> # a name such as predict_line() would avoid the clash.
> predict <- function(x, a, b){
+ return (a + b * x)
+ }
>
> # Residuals: observed y minus predictions. This is not the loss itself;
> # the loss (SSR) is built from these values in ssrloss().
> # NOTE(review): this masks stats::residuals() in the global environment.
> residuals <- function(predictions, y) {
+ return(y - predictions)
+ }
>
> # Loss: the sum of squared residuals (SSR). Squaring stops positive and
> # negative errors from cancelling out, which is also why the value is
> # often very large.
> ssrloss <- function(predictions, y) {
+ # local numeric vector of residuals, computed directly
+ residuals <- (y - predictions)
+ return(sum(residuals^2))
+ }
>
> # Accumulators filled by the grid search below, one entry per candidate a.
> # NOTE(review): growing vectors with append() copies them on every
> # iteration; preallocating (e.g. numeric(10001)) would be faster.
> ssrs <- c() # sum of squared residuals for each candidate a
> srs <- c() # plain (signed) sum of residuals for each candidate a
> # NOTE(review): `as` is also the name of a base R function; a name such
> # as `intercepts` would be clearer.
> as <- c() # candidate intercepts a
>
> # Grid search over the intercept: try a from -50 to 50 in steps of 0.01
> # (10001 values) with the slope held at b, and record the SSR, the sum
> # of residuals and the candidate a at each step.
> for (i in seq(from = -50, to = 50, by = 0.01)) {
+ pred <- predict(x, i, b)
+ res <- residuals(pred, y)
+ ssr <- ssrloss(pred, y)
+ ssrs <- append(ssrs, ssr)
+ srs <- append(srs, sum(res))
+ as <- append(as, i)
+ }
> # Each accumulator holds one entry per grid point.
> length(ssrs)
[1] 10001
> length(srs)
[1] 10001
> length(as)
[1] 10001
>
> # Smallest SSR found on the grid, and where it occurs.
> min(ssrs)
[1] 1553336
> # which() returns every index equal to the minimum; which.min(ssrs)
> # would return only the first one.
> min.pos.ssrs <- which(ssrs == min(ssrs))
> min.pos.ssrs
[1] 5828
> # The best grid intercept (8.27) agrees with lm()'s estimate (8.266)
> # to within the grid step of 0.01.
> print(as[min.pos.ssrs])
[1] 8.27
> summary(mo)
Call:
lm(formula = y ~ x, data = data)
Residuals:
Min 1Q Median 3Q Max
-259.314 -59.215 6.683 58.834 309.833
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 8.266 12.546 0.659 0.511
x 11.888 2.433 4.887 2.11e-06 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 88.57 on 198 degrees of freedom
Multiple R-squared: 0.1076, Adjusted R-squared: 0.1031
F-statistic: 23.88 on 1 and 198 DF, p-value: 2.111e-06
> # SSR is quadratic in a, so the first plot is a parabola whose bottom is
> # at index 5828. The sum of residuals is linear in a: it drops by nx (200)
> # per unit of a, i.e. by 2 per grid step, and it is about 0 at the same
> # index where SSR is minimal.
> plot(seq(1, length(ssrs)), ssrs)
> plot(seq(1, length(ssrs)), srs)
> tail(ssrs)
[1] 1900842 1901008 1901175 1901342 1901509 1901676
> max(ssrs)
[1] 2232329
> min(ssrs)
[1] 1553336
> tail(srs)
[1] -8336.735 -8338.735 -8340.735 -8342.735 -8344.735 -8346.735
> max(srs)
[1] 11653.26
> min(srs)
[1] -8346.735
>
>
gradient_descent/output01.1766083965.txt.gz · Last modified: by hkimscil
