tidymodels / brulee
High-Level Modeling Functions with 'torch'
Home Page: https://brulee.tidymodels.org/
License: Other
Right now, for simple generalized linear models, the slopes and intercepts are returned in different arrays.
Finishing this depends on #32
Add an argument called hidden, a vector with the number of hidden units for each hidden layer. The number of hidden layers would be equal to length(hidden).
These were removed from the initial version since I had trouble with convergence.
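To make the proposed hidden argument concrete, here is a minimal sketch that turns such a vector into layer widths and counts the weights and biases of a fully connected network. The helper name mlp_param_count() is hypothetical and not part of brulee; it only illustrates that the number of hidden layers is length(hidden) and that an empty vector reduces to a model with p + 1 parameters.

```r
# Hypothetical helper: count weights and biases of a fully connected network
# whose hidden layers are described by `hidden` (one entry per layer).
mlp_param_count <- function(n_features, hidden, n_outputs = 1) {
  # Layer widths from the input layer through the output layer
  sizes <- c(n_features, hidden, n_outputs)
  n_in  <- sizes[-length(sizes)]
  n_out <- sizes[-1]
  # Each layer has an (n_in x n_out) weight matrix plus n_out biases
  sum(n_in * n_out + n_out)
}

hidden <- c(5, 3)
length(hidden)              # 2 hidden layers
mlp_param_count(2, hidden)  # 2 inputs -> 5 -> 3 -> 1 output: 37
mlp_param_count(2, integer(0))  # no hidden layers: p + 1 = 3
```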
When predicting on new data with the updated torch version 0.11.0, I get an error: !self$..refer_to_state_dict.. : invalid argument type. I believe it is linked to the breaking change mentioned below.
install.packages("torch") # Version 0.11.0
torch::install_torch(reinstall = TRUE)
library(brulee)
library(recipes)
library(yardstick)
data(bivariate, package = "modeldata")
set.seed(20)
nn_log_biv <- brulee_mlp(Class ~ log(A) + log(B), data = bivariate_train,
epochs = 150, hidden_units = 3)
# We use the tidymodels semantics to always return a tibble when predicting
predict(nn_log_biv, bivariate_test, type = "prob") %>%
bind_cols(bivariate_test) %>%
roc_auc(Class, .pred_One)
# Error in !self$..refer_to_state_dict.. : invalid argument type
Prepare for release:
git pull
devtools::build_readme()
urlchecker::url_check()
devtools::check(remote = TRUE, manual = TRUE)
devtools::check_win_devel()
rhub::check_for_cran()
revdepcheck::cloud_check()
cran-comments.md
git push
Submit to CRAN:
usethis::use_version('minor')
devtools::submit_cran()
Wait for CRAN...
git push
usethis::use_github_release()
usethis::use_dev_version()
git push
I cannot use linear activation functions in brulee_mlp(). I instead get my favorite R error :)
library(brulee)
library(recipes)
library(yardstick)
data(bivariate, package = "modeldata")
set.seed(20)
nn_log_biv <- brulee_mlp(Class ~ log(A) + log(B), data = bivariate_train,
epochs = 150, hidden_units = 3, activation = "linear")
#> Error in x$parameters: object of type 'closure' is not subsettable
R version 4.3.1 (2023-06-16 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 11 x64 (build 22621)
Matrix products: default
locale:
[1] LC_COLLATE=English_Austria.utf8 LC_CTYPE=English_Austria.utf8
[3] LC_MONETARY=English_Austria.utf8 LC_NUMERIC=C
[5] LC_TIME=English_Austria.utf8
time zone: Europe/Vienna
tzcode source: internal
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] yardstick_1.2.0 brulee_0.2.0 recipes_1.0.6 dplyr_1.1.2
loaded via a namespace (and not attached):
[1] utf8_1.2.3 future_1.32.0 generics_0.1.3
[4] class_7.3-22 lattice_0.21-8 listenv_0.9.0
[7] digest_0.6.31 magrittr_2.0.3 grid_4.3.1
[10] timechange_0.2.0 rprojroot_2.0.3 Matrix_1.5-4.1
[13] processx_3.8.2 nnet_7.3-19 survival_3.5-5
[16] torch_0.11.0 ps_1.7.5 purrr_1.0.1
[19] fansi_1.0.4 scales_1.2.1 coro_1.0.3
[22] codetools_0.2-19 lava_1.7.2.1 cli_3.6.1
[25] rlang_1.1.1 hardhat_1.3.0 parallelly_1.36.0
[28] future.apply_1.11.0 munsell_0.5.0 bit64_4.0.5
[31] splines_4.3.1 withr_2.5.0 prodlim_2023.03.31
[34] tools_4.3.1 parallel_4.3.1 colorspace_2.1-0
[37] ggplot2_3.4.3 globals_0.16.2 vctrs_0.6.2
[40] R6_2.5.1 rpart_4.1.19 lifecycle_1.0.3
[43] lubridate_1.9.2 bit_4.0.5 MASS_7.3-60
[46] desc_1.4.2 pkgconfig_2.0.3 callr_3.7.3
[49] gtable_0.3.3 pillar_1.9.0 data.table_1.14.8
[52] glue_1.6.2 Rcpp_1.0.10 tibble_3.2.1
[55] tidyselect_1.2.0 ipred_0.9-14 timeDate_4022.108
[58] gower_1.0.1 compiler_4.3.1
Currently, relu and elu are there, but tanh, linear, and softmax would also be good ideas.
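For reference, the math behind each of these names fits in a few lines of base R. This is a sketch only: brulee delegates activations to torch modules, and the activations list and get_activation() lookup below are hypothetical names used for illustration.

```r
# Hypothetical lookup table: element-wise activations, plus softmax,
# which normalizes a whole vector
activations <- list(
  relu    = function(x) pmax(x, 0),
  elu     = function(x, alpha = 1) ifelse(x > 0, x, alpha * (exp(x) - 1)),
  tanh    = base::tanh,
  linear  = function(x) x,
  softmax = function(x) {
    z <- exp(x - max(x))  # subtract the max for numerical stability
    z / sum(z)
  }
)

# Validate the name against the table, then return the function
get_activation <- function(name) {
  name <- match.arg(name, names(activations))
  activations[[name]]
}

get_activation("relu")(c(-1, 0, 2))     # 0 0 2
sum(get_activation("softmax")(1:3))     # 1
```

Using match.arg() means an unsupported name fails early with an informative error instead of surfacing later inside the fitting code.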
example here and here.
They are currently in different files
Prepare for release:
git pull
devtools::build_readme()
urlchecker::url_check()
devtools::check(remote = TRUE, manual = TRUE)
devtools::check_win_devel()
rhub::check_for_cran()
revdepcheck::cloud_check()
cran-comments.md
git push
Submit to CRAN:
usethis::use_version('minor')
devtools::submit_cran()
Wait for CRAN...
git push
usethis::use_github_release()
usethis::use_dev_version()
git push
I'm having trouble training an mlp() specification with the brulee engine. I know that brulee uses torch, and I checked that torch can see my GPU, which seems okay. But in the example below, training runs on the CPU.
suppressPackageStartupMessages({
library(tidymodels)
library(torch)
})
torch::cuda_is_available()
#> TRUE
torch::cuda_device_count()
#> [1] 1
set.seed(1)
modspec <- mlp(hidden_units = tune(),
penalty = tune(),
epochs = tune(),
activation = tune(),
learn_rate = tune()) %>%
set_mode('classification') %>%
set_engine('brulee')
fk_param <- modspec %>%
extract_parameter_set_dials %>%
grid_max_entropy(size = 50)
spl_obj <- initial_split(iris,.7)
cv_obj <- vfold_cv(training(spl_obj),5)
rcp <- recipe(formula = Species ~
Sepal.Width +
Sepal.Length +
Petal.Width +
Petal.Length,
data = training(spl_obj)) %>%
step_normalize(all_numeric_predictors())
wf <- workflow() %>%
add_model(modspec) %>%
add_recipe(rcp)
cv_fit <- wf %>%
tune_grid(resamples = cv_obj,grid = fk_param)
We want to be able to get predictions from previous models (over iterations). Right now, we use model_to_raw() to save each one. For a relatively simple neural network with 12 parameters, each object is about 900 KB in memory. If there are many iterations, the model fit can be huge.
Is it feasible to store a single model object (say at the last epoch), save the parameters at each epoch, then stuff them back into the model object at prediction time?
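A minimal sketch of the idea, assuming a toy linear model fitted by full-batch gradient descent stands in for the network: each epoch stores only its parameter values as plain R vectors, and one prediction function plugs in whichever epoch is requested. fit_with_snapshots() and predict_epoch() are hypothetical; in brulee the stored pieces would be the torch parameters converted to R arrays.

```r
# Hypothetical: keep plain R copies of the parameters at every epoch
fit_with_snapshots <- function(x, y, epochs = 50, learn_rate = 0.1) {
  weight <- numeric(ncol(x))
  bias <- 0
  snapshots <- vector("list", epochs)
  for (epoch in seq_len(epochs)) {
    # One full-batch gradient step on mean squared error
    resid  <- drop(bias + x %*% weight) - y
    weight <- weight - learn_rate * drop(crossprod(x, resid)) / nrow(x)
    bias   <- bias - learn_rate * mean(resid)
    # Store only the numbers, not a serialized model object
    snapshots[[epoch]] <- list(weight = weight, bias = bias)
  }
  snapshots
}

# Put the requested epoch's parameters back into the model form and predict
predict_epoch <- function(snapshots, new_x, epoch = length(snapshots)) {
  p <- snapshots[[epoch]]
  drop(p$bias + new_x %*% p$weight)
}

set.seed(1)
x <- matrix(rnorm(200), ncol = 2)
y <- drop(1 + x %*% c(2, -1))
snaps <- fit_with_snapshots(x, y)
predict_epoch(snaps, x[1:3, ], epoch = 5)  # early-epoch predictions
predict_epoch(snaps, x[1:3, ])             # last-epoch predictions
```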
It would be nice to support L1 and L2 regularization
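As a sketch of what that would add to the objective: an L1 term sums absolute weights and an L2 term sums squared weights, each scaled by its own strength, with the intercept conventionally left unpenalized. The function and its l1/l2 arguments are hypothetical, not brulee arguments.

```r
# Hypothetical penalized loss: mean squared error plus L1 and L2 terms.
# `weights` should exclude the intercept, which is usually not penalized.
penalized_loss <- function(y, pred, weights, l1 = 0, l2 = 0) {
  mse <- mean((y - pred)^2)
  mse + l1 * sum(abs(weights)) + l2 * sum(weights^2)
}

# Perfect predictions, so only the penalty remains: 0.1 * 3 + 0.01 * 5
penalized_loss(y = c(1, 2), pred = c(1, 2), weights = c(1, -2),
               l1 = 0.1, l2 = 0.01)  # 0.35
```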
For whatever reason, linear regression is not converging to the coefficients found using the analytic solution.
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 0.1.2 ──
#> ✔ broom 0.7.2 ✔ recipes 0.1.15
#> ✔ dials 0.0.9.9000 ✔ rsample 0.0.8
#> ✔ dplyr 1.0.2 ✔ tibble 3.0.4
#> ✔ ggplot2 3.3.2 ✔ tidyr 1.1.2
#> ✔ infer 0.5.3 ✔ tune 0.1.2
#> ✔ modeldata 0.1.0.9000 ✔ workflows 0.2.1
#> ✔ parsnip 0.1.4.9000 ✔ yardstick 0.0.7
#> ✔ purrr 0.3.4
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
library(lantern)
rec <- recipe(mpg ~ ., data = mtcars) %>% step_normalize(all_predictors())
set.seed(1)
torch_2 <- lantern_linear_reg(rec, data = mtcars, epochs = 2)
set.seed(1)
torch_10 <- lantern_linear_reg(rec, data = mtcars, epochs = 10)
set.seed(1)
torch_100 <- lantern_linear_reg(rec, data = mtcars, epochs = 100)
set.seed(1)
torch_1000 <- lantern_linear_reg(rec, data = mtcars, epochs = 1000)
lm_fit <- lm(mpg ~ ., data = rec %>% prep() %>% juice())
tibble(
term = names(coef(lm_fit)),
lm = coef(lm_fit),
`2_iter` = torch_2$coefs,
`10_iter` = torch_10$coefs,
`100_iter` = torch_100$coefs,
`1000_iter` = torch_1000$coefs
)
#> # A tibble: 11 x 6
#> term lm `2_iter` `10_iter` `100_iter` `1000_iter`
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 (Intercept) 20.1 20.1 20.1 20.1 20.1
#> 2 cyl -0.199 -0.328 -0.885 -0.600 -0.260
#> 3 disp 1.65 -0.0938 -0.585 -0.801 0.105
#> 4 hp -1.47 -0.0162 -0.572 -0.988 -0.905
#> 5 drat 0.421 0.332 0.680 0.664 0.493
#> 6 wt -3.64 -0.00650 -0.585 -1.46 -2.29
#> 7 qsec 1.47 0.0288 0.220 0.315 0.851
#> 8 vs 0.160 0.291 0.239 0.349 0.141
#> 9 am 1.26 0.0764 0.160 0.804 1.17
#> 10 gear 0.484 0.286 0.582 0.294 0.553
#> 11 carb -0.322 -0.414 -0.427 -0.783 -1.12
Created on 2020-12-03 by the reprex package (v0.3.0)
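For comparison, a base-R sketch of plain full-batch gradient descent on the same standardized mtcars design does reach lm()'s coefficients, but only after many steps, because predictors such as cyl, disp, and wt are strongly correlated and make X'X poorly conditioned. This does not diagnose lantern's code, which may differ in optimizer, learning rate, or batching; the step-size rule below is a standard choice assumed for illustration.

```r
# Standardize the predictors as step_normalize() does: center, divide by sd
X <- cbind("(Intercept)" = 1, scale(as.matrix(mtcars[, -1])))
y <- mtcars$mpg
n <- nrow(X)

# Full-batch gradient descent on half the mean squared error.
# A step of 1 / (largest eigenvalue of X'X / n) cannot diverge.
H <- crossprod(X) / n
step <- 1 / eigen(H, symmetric = TRUE, only.values = TRUE)$values[1]

beta <- numeric(ncol(X))
for (iter in 1:100000) {
  grad <- drop(crossprod(X, X %*% beta - y)) / n
  if (max(abs(grad)) < 1e-10) break
  beta <- beta - step * grad
}

# Reference solution from ordinary least squares
lm_fit <- lm(mpg ~ ., data = data.frame(mpg = y, scale(mtcars[, -1])))
cbind(gradient_descent = beta, lm = coef(lm_fit))
iter  # number of full-batch steps taken
```

The number of steps needed grows with the ratio of the largest to smallest eigenvalue of H, so a handful of epochs is not expected to match the analytic solution on this data.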
I'm having trouble using the coef function to extract coefficients from a brulee_logistic_reg when the model was specified using a recipe and not using matrices.
An example adapted from the package
library(brulee)
library(recipes)
#> Loading required package: dplyr
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
#>
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#>
#> step
library(yardstick)
data(cells, package = "modeldata")
cells$case <- NULL
set.seed(122)
in_train <- sample(1:nrow(cells), 1000)
cells_train <- cells[ in_train,]
cells_test <- cells[-in_train,]
# Using matrices
set.seed(1)
reg_from_matrices <- brulee_logistic_reg(x = as.matrix(cells_train[, c("fiber_width_ch_1", "width_ch_1")]),
y = cells_train$class,
penalty = 0.10, epochs = 3)
coef(reg_from_matrices)
#> (Intercept) fiber_width_ch_1 width_ch_1
#> -4.0048534 0.2109342 -0.2108857
# Using recipe
cells_rec <-
recipe(class ~ ., data = cells_train) %>%
step_YeoJohnson(all_numeric_predictors()) %>%
step_normalize(all_numeric_predictors()) %>%
step_pca(all_numeric_predictors(), num_comp = 10)
set.seed(2)
reg_from_recipe <- brulee_logistic_reg(cells_rec, data = cells_train,
penalty = .01, epochs = 5)
coef(reg_from_recipe)
#> Error in names(param) <- c("(Intercept)", object$dims$features): 'names' attribute [11] must be the same length as the vector [3]
Created on 2023-08-22 with reprex v2.0.2
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.3.1 (2023-06-16)
#> os Ubuntu 22.04.2 LTS
#> system x86_64, linux-gnu
#> ui X11
#> language (EN)
#> collate fr_FR.UTF-8
#> ctype fr_FR.UTF-8
#> tz Europe/Paris
#> date 2023-08-22
#> pandoc 3.1.1 @ /usr/lib/rstudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> bit 4.0.5 2022-11-15 [1] CRAN (R 4.3.1)
#> bit64 4.0.5 2020-08-30 [1] CRAN (R 4.3.1)
#> brulee * 0.2.0.9000 2023-08-22 [1] Github (tidymodels/brulee@087129b)
#> callr 3.7.3 2022-11-02 [1] CRAN (R 4.3.1)
#> class 7.3-22 2023-05-03 [4] CRAN (R 4.3.1)
#> cli 3.6.1 2023-03-23 [1] CRAN (R 4.3.1)
#> codetools 0.2-19 2023-02-01 [4] CRAN (R 4.2.2)
#> colorspace 2.1-0 2023-01-23 [1] CRAN (R 4.3.1)
#> coro 1.0.3 2022-07-19 [1] CRAN (R 4.3.1)
#> data.table 1.14.8 2023-02-17 [1] CRAN (R 4.3.1)
#> digest 0.6.33 2023-07-07 [1] CRAN (R 4.3.1)
#> dplyr * 1.1.2 2023-04-20 [1] CRAN (R 4.3.1)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.3.1)
#> evaluate 0.21 2023-05-05 [1] CRAN (R 4.3.1)
#> fansi 1.0.4 2023-01-22 [1] CRAN (R 4.3.1)
#> fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.3.1)
#> fs 1.6.3 2023-07-20 [1] CRAN (R 4.3.1)
#> future 1.33.0 2023-07-01 [1] CRAN (R 4.3.1)
#> future.apply 1.11.0 2023-05-21 [1] CRAN (R 4.3.1)
#> generics 0.1.3 2022-07-05 [1] CRAN (R 4.3.1)
#> ggplot2 3.4.3 2023-08-14 [1] CRAN (R 4.3.1)
#> globals 0.16.2 2022-11-21 [1] CRAN (R 4.3.1)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.3.1)
#> gower 1.0.1 2022-12-22 [1] CRAN (R 4.3.1)
#> gtable 0.3.4 2023-08-21 [1] CRAN (R 4.3.1)
#> hardhat 1.3.0 2023-03-30 [1] CRAN (R 4.3.1)
#> htmltools 0.5.5 2023-03-23 [1] CRAN (R 4.3.1)
#> ipred 0.9-14 2023-03-09 [1] CRAN (R 4.3.1)
#> knitr 1.43 2023-05-25 [1] CRAN (R 4.3.1)
#> lattice 0.21-8 2023-04-05 [4] CRAN (R 4.3.0)
#> lava 1.7.2.1 2023-02-27 [1] CRAN (R 4.3.1)
#> lifecycle 1.0.3 2022-10-07 [1] CRAN (R 4.3.1)
#> listenv 0.9.0 2022-12-16 [1] CRAN (R 4.3.1)
#> lubridate 1.9.2 2023-02-10 [1] CRAN (R 4.3.1)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.1)
#> MASS 7.3-60 2023-05-04 [4] CRAN (R 4.3.1)
#> Matrix 1.6-0 2023-07-08 [1] CRAN (R 4.3.1)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.3.1)
#> nnet 7.3-19 2023-05-03 [4] CRAN (R 4.3.1)
#> parallelly 1.36.0 2023-05-26 [1] CRAN (R 4.3.1)
#> pillar 1.9.0 2023-03-22 [1] CRAN (R 4.3.1)
#> pkgconfig 2.0.3 2019-09-22 [2] CRAN (R 4.0.2)
#> processx 3.8.2 2023-06-30 [1] CRAN (R 4.3.1)
#> prodlim 2023.03.31 2023-04-02 [1] CRAN (R 4.3.1)
#> ps 1.7.5 2023-04-18 [1] CRAN (R 4.3.1)
#> purrr 1.0.1 2023-01-10 [1] CRAN (R 4.3.1)
#> R.cache 0.16.0 2022-07-21 [1] CRAN (R 4.3.1)
#> R.methodsS3 1.8.2 2022-06-13 [1] CRAN (R 4.3.1)
#> R.oo 1.25.0 2022-06-12 [1] CRAN (R 4.3.1)
#> R.utils 2.12.2 2022-11-11 [1] CRAN (R 4.3.1)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.1)
#> Rcpp 1.0.11 2023-07-06 [1] CRAN (R 4.3.1)
#> recipes * 1.0.6 2023-04-25 [1] CRAN (R 4.3.1)
#> reprex 2.0.2 2022-08-17 [1] CRAN (R 4.3.1)
#> rlang 1.1.1 2023-04-28 [1] CRAN (R 4.3.1)
#> rmarkdown 2.23 2023-07-01 [1] CRAN (R 4.3.1)
#> rpart 4.1.19 2022-10-21 [4] CRAN (R 4.2.1)
#> rstudioapi 0.15.0 2023-07-07 [1] CRAN (R 4.3.1)
#> scales 1.2.1 2022-08-20 [1] CRAN (R 4.3.1)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.3.1)
#> styler 1.10.1 2023-06-05 [1] CRAN (R 4.3.1)
#> survival 3.5-5 2023-03-12 [4] CRAN (R 4.3.1)
#> tibble 3.2.1 2023-03-20 [1] CRAN (R 4.3.1)
#> tidyselect 1.2.0 2022-10-10 [1] CRAN (R 4.3.1)
#> timechange 0.2.0 2023-01-11 [1] CRAN (R 4.3.1)
#> timeDate 4022.108 2023-01-07 [1] CRAN (R 4.3.1)
#> torch 0.11.0 2023-06-06 [1] CRAN (R 4.3.1)
#> utf8 1.2.3 2023-01-31 [1] CRAN (R 4.3.1)
#> vctrs 0.6.3 2023-06-14 [1] CRAN (R 4.3.1)
#> withr 2.5.0 2022-03-03 [1] CRAN (R 4.3.1)
#> xfun 0.39 2023-04-20 [1] CRAN (R 4.3.1)
#> yaml 2.2.1 2020-02-01 [2] CRAN (R 4.0.2)
#> yardstick * 1.2.0 2023-04-21 [1] CRAN (R 4.3.1)
#>
#> [1] /home/jaubert/R/x86_64-pc-linux-gnu-library/4.3
#> [2] /usr/local/lib/R/site-library
#> [3] /usr/lib/R/site-library
#> [4] /usr/lib/R/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
This would be much easier to tune
From @dfalbel...
We might want to implement a learning rate decay technique otherwise it's too hard to tune the learning rate correctly.
Maybe something like what sklearn does:
'invscaling' gradually decreases the learning rate at each time step 't' using an inverse scaling exponent of 'power_t': effective_learning_rate = learning_rate_init / pow(t, power_t)
'adaptive' keeps the learning rate constant to 'learning_rate_init' as long as training loss keeps decreasing. Each time two consecutive epochs fail to decrease training loss by at least tol, or fail to increase validation score by at least tol if 'early_stopping' is on, the current learning rate is divided by 5.
Currently, the model structure is a neural network with no hidden units. This leads to a valid model but has more parameters than required for this simple model:
$fc1.weight
[,1] [,2]
[1,] -0.6381196 0.9604316
[2,] 0.5626327 -0.2731194
$fc1.bias
[1] 0.3900663 0.3431486
So for p model terms we can create a model with p + 1 parameters.
This should also be the case for multinomial regression (in #31).
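Why p + 1 is enough for two classes: the softmax probability of the first class depends only on the difference between the two linear predictors, so subtracting the second output's weights and bias from the first gives an equivalent logistic model. The sketch below reuses the parameter values printed above; the predictor row x and the names term_1 and term_2 are illustrative assumptions.

```r
# Parameters printed above: one row of weights and one bias per output unit
# (torch's nn_linear stores weights as outputs x inputs)
w <- rbind(c(-0.6381196,  0.9604316),
           c( 0.5626327, -0.2731194))
b <- c(0.3900663, 0.3431486)

x <- c(0.2, -1.5)  # an arbitrary row of p = 2 predictors (illustrative)

# Current form: softmax over two linear predictors, 2 * (p + 1) parameters
eta <- drop(w %*% x) + b
p_softmax <- exp(eta) / sum(exp(eta))

# GLM form: one intercept and p slopes, returned in a single named vector
slopes    <- w[1, ] - w[2, ]
intercept <- b[1] - b[2]
coefs <- c("(Intercept)" = intercept, term_1 = slopes[1], term_2 = slopes[2])
p_logistic <- plogis(intercept + sum(slopes * x))

all.equal(p_softmax[1], p_logistic)  # TRUE
```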
for CRAN release
Prepare for release:
devtools::build_readme()
urlchecker::url_check()
devtools::check(remote = TRUE, manual = TRUE)
devtools::check_win_devel()
rhub::check_for_cran()
revdepcheck::cloud_check()
cran-comments.md
Submit to CRAN:
usethis::use_version('minor')
devtools::submit_cran()
Wait for CRAN...
usethis::use_github_release()
usethis::use_dev_version()
I suggest simplifying matrix_to_dataset() using torch::tensor_dataset(), which also checks the dimensions.
This will replace the following lines
https://github.com/tidymodels/lantern/blob/db48129496d9fe1ffcf85e1a595ca9d5c8e85461/R/convert_data.R#L15-L42
with something like:
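matrix_to_dataset <- function(x, y) {
  # Predictors as a tensor
  x <- torch::torch_tensor(x)
  if (is.factor(y)) {
    # Class labels as 1-based integer codes, stored as a long tensor
    y <- torch::torch_tensor(as.integer(y), dtype = torch::torch_long())
  } else {
    y <- torch::torch_tensor(y)
  }
  # tensor_dataset() checks that x and y agree in their first dimension
  torch::tensor_dataset(x = x, y = y)
}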
If you agree, I can create a PR.
For this code
library(lantern)
set.seed(1)
df <- tibble::tibble(
x1 = runif(100),
x2 = runif(100),
y = 3 + 2*x1 + 3*x2
)
# Log for an upcoming issue:
set.seed(1)
lantern_linear_reg(y ~ ., df, epochs = 2, verbose = TRUE)
GHA gives different results on different operating systems:
macOS:
epoch: 1 Loss (scaled): 1.57e-12
epoch: 2 Loss (scaled): 1.57e-12 ✖
ubuntu:
epoch: 1 Loss (scaled): 1.46e-12
epoch: 2 Loss (scaled): 1.46e-12 ✖
windows:
epoch: 1 Loss (scaled): 1.57e-12
epoch: 2 Loss (scaled): 1.57e-12 x
Since this package is on CRAN now it should get a proper URL. This is currently "breaking" https://www.tidymodels.org/find/#search-parsnip-models since it links to the wrong place.
Restrict lantern_logistic_reg() to two classes.
When choosing the LBFGS optimizer in brulee_mlp(), it gets set in the following part of mlp_fit_imp():
Lines 660 to 668 in a94ec7a
However, in each epoch, the optimizer gets overwritten with an SGD optimizer:
Line 697 in a94ec7a
This means the LBFGS optimizer is never used.
Is this intended behaviour?
Currently, brulee_mlp() supports "relu", "elu", "tanh", and "linear" activation functions. It would be great to also have "sigmoid" and "softmax", as they are also supported via torch::nn_sigmoid() and torch::nn_softmax(). I assume it's just a few additional lines in get_activation_fn(): https://github.com/tidymodels/brulee/blob/087129b0a71e63f16137934f89091b4db7fa4351/R/mlp-fit.R#L865C22-L865C22
They are not reproducible across hardware and operating systems. Use equivalence tests with suitable tolerances for unit and regression tests.
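A small base-R illustration of the difference, using the two scaled losses from the GHA log above: exact comparison fails, while all.equal() with an explicit tolerance passes. Setting scale = 1 makes the tolerance absolute, which suits values near zero; for larger quantities a relative tolerance is usually more meaningful. In a testthat suite the same idea is expect_equal() with its tolerance argument.

```r
loss_macos  <- 1.57e-12
loss_ubuntu <- 1.46e-12

identical(loss_macos, loss_ubuntu)  # FALSE: exact equality is too strict

# Absolute tolerance (scale = 1) for quantities near zero
isTRUE(all.equal(loss_macos, loss_ubuntu, tolerance = 1e-8, scale = 1))  # TRUE

# Relative tolerance (the default behaviour) for larger quantities
isTRUE(all.equal(20.1, 20.1 + 1e-7, tolerance = 1e-6))  # TRUE
```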
Includes: initial, decay, reduction, steps, largest, and step_size. Maybe preface the function names with rate_.
Also update the tunable() method to be able to use them seamlessly.
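A sketch of what schedule functions with these argument names might look like, following the rate_ prefix floated above. The function names and formulas are assumptions for illustration: exponential decay is one reading of decay (a time-based initial / (1 + decay * epoch) is another), and the cyclic version is the common triangular schedule.

```r
# Hypothetical schedules; each returns the learning rate at a 0-based epoch
rate_decay_expo <- function(epoch, initial = 0.1, decay = 1) {
  initial * exp(-decay * epoch)
}

rate_step <- function(epoch, initial = 0.1, reduction = 1/2, steps = 5) {
  # Multiply by `reduction` once every `steps` epochs
  initial * reduction^floor(epoch / steps)
}

rate_cyclic <- function(epoch, initial = 0.001, largest = 0.1, step_size = 5) {
  # Triangular cycle: rise from `initial` to `largest` over `step_size`
  # epochs, then fall back over the next `step_size` epochs
  cycle <- floor(1 + epoch / (2 * step_size))
  x <- abs(epoch / step_size - 2 * cycle + 1)
  initial + (largest - initial) * pmax(0, 1 - x)
}

rate_step(0:10)            # 0.1 for epochs 0-4, 0.05 for 5-9, then 0.025
rate_cyclic(c(0, 5, 10))   # 0.001 0.100 0.001
```

Keeping epoch as the first argument and everything else as named tuning parameters is what would let a tunable() method expose them.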
It looks like the log-loss function allows for class weights.
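For reference, a base-R sketch of a class-weighted log loss: each observation's negative log probability of its true class is multiplied by that class's weight, and the total is divided by the summed weights. torch's nn_cross_entropy_loss() takes a weight argument; that its mean reduction normalizes the same way is assumed here from PyTorch's documented behaviour, and weighted_log_loss() itself is hypothetical.

```r
# prob: n x K matrix of predicted probabilities
# truth: integer class index (1..K) for each row
# class_weights: one weight per class
weighted_log_loss <- function(prob, truth, class_weights) {
  p_true <- prob[cbind(seq_along(truth), truth)]  # probability of true class
  w <- class_weights[truth]                        # weight of each row's class
  sum(-w * log(p_true)) / sum(w)
}

prob <- rbind(c(0.8, 0.2),
              c(0.4, 0.6))
# Second class counts three times as much as the first
weighted_log_loss(prob, truth = c(1, 2), class_weights = c(1, 3))
```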
2023
Necessary:
person(given = "Posit Software, PBC", role = c("cph", "fnd"))
use_mit_license()
use_tidy_logo()
usethis::use_tidy_coc()
usethis::use_tidy_github_actions()
Optional:
Prefer pak::pak("org/pkg") over devtools::install_github("org/pkg") in README
use_tidy_dependencies() and/or replace compat files with use_standalone()
use_standalone("r-lib/rlang", "types-check") instead of home grown argument checkers
Prepare for release:
git pull
urlchecker::url_check()
devtools::build_readme()
devtools::check(remote = TRUE, manual = TRUE)
devtools::check_win_devel()
revdepcheck::cloud_check()
cran-comments.md
git push
Submit to CRAN:
usethis::use_version('minor')
devtools::submit_cran()
Wait for CRAN...
usethis::use_github_release()
usethis::use_dev_version(push = TRUE)
This is partly due to randomness; different random numbers may not fail.
library(tidymodels)
library(brulee)
tidymodels_prefer()
data(ames, package = "modeldata")
ames$Sale_Price <- log10(ames$Sale_Price)
set.seed(122)
in_train <- sample(1:nrow(ames), 2000)
ames_train <- ames[ in_train,]
ames_test <- ames[-in_train,]
set.seed(1)
brulee_linear_reg(x = as.matrix(ames_train[, c("Longitude", "Latitude")]),
y = ames_train$Sale_Price,
penalty = 0.10, epochs = 10, batch_size = 64,
optimizer = "SGD", verbose = TRUE)
#> Warning: Current loss in NaN. Training wil be stopped.
#> Linear regression
#>
#> 2,000 samples, 2 features, numeric outcome
#> weight decay: 0.1
#> batch size: 64
#> scaled validation loss after 1 epoch: NaN
Created on 2023-11-02 with reprex v2.0.2
For each model, have a list of acceptable loss functions and an argument to switch.
> ls(pattern = "loss$", envir = asNamespace("torch"))
[1] "nn_adaptive_log_softmax_with_loss" "nn_bce_loss"
[3] "nn_bce_with_logits_loss" "nn_cosine_embedding_loss"
[5] "nn_cross_entropy_loss" "nn_ctc_loss"
[7] "nn_hinge_embedding_loss" "nn_kl_div_loss"
[9] "nn_l1_loss" "nn_loss"
[11] "nn_margin_ranking_loss" "nn_mse_loss"
[13] "nn_multi_margin_loss" "nn_multilabel_margin_loss"
[15] "nn_multilabel_soft_margin_loss" "nn_nll_loss"
[17] "nn_poisson_nll_loss" "nn_smooth_l1_loss"
[19] "nn_soft_margin_loss" "nn_triplet_margin_loss"
[21] "nn_triplet_margin_with_distance_loss" "nn_weighted_loss"
[23] "nnf_cosine_embedding_loss" "nnf_ctc_loss"
[25] "nnf_hinge_embedding_loss" "nnf_l1_loss"
[27] "nnf_margin_ranking_loss" "nnf_mse_loss"
[29] "nnf_multi_margin_loss" "nnf_multilabel_margin_loss"
[31] "nnf_multilabel_soft_margin_loss" "nnf_nll_loss"
[33] "nnf_poisson_nll_loss" "nnf_smooth_l1_loss"
[35] "nnf_soft_margin_loss" "nnf_triplet_margin_loss"
[37] "nnf_triplet_margin_with_distance_loss" "torch__ctc_loss"
[39] "torch__cudnn_ctc_loss" "torch__use_cudnn_ctc_loss"
[41] "torch_cosine_embedding_loss" "torch_cross_entropy_loss"
[43] "torch_ctc_loss" "torch_hinge_embedding_loss"
[45] "torch_huber_loss" "torch_l1_loss"
[47] "torch_margin_ranking_loss" "torch_mse_loss"
[49] "torch_multi_margin_loss" "torch_multilabel_margin_loss"
[51] "torch_nll_loss" "torch_poisson_nll_loss"
[53] "torch_smooth_l1_loss" "torch_soft_margin_loss"
[55] "torch_triplet_margin_loss"
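One way to wire this up is a per-model table of allowed losses plus a validator that resolves the user's choice, as sketched below. The table contents, model keys, short loss names, and check_loss() are all hypothetical; only the torch module names on the right-hand side come from the listing above.

```r
# Hypothetical table: short loss names -> torch loss module names
allowed_losses <- list(
  linear_reg   = c(mse = "nn_mse_loss", l1 = "nn_l1_loss",
                   smooth_l1 = "nn_smooth_l1_loss"),
  logistic_reg = c(cross_entropy = "nn_cross_entropy_loss",
                   bce_logits = "nn_bce_with_logits_loss")
)

check_loss <- function(model, loss = NULL) {
  choices <- allowed_losses[[model]]
  if (is.null(choices)) stop("Unknown model type: ", model)
  # The first entry acts as the default
  if (is.null(loss)) loss <- names(choices)[1]
  loss <- match.arg(loss, names(choices))  # errors on unsupported losses
  unname(choices[loss])
}

check_loss("linear_reg", "l1")  # "nn_l1_loss"
check_loss("linear_reg")        # "nn_mse_loss"
```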
Implement a torch version of caret:::GarsonWeights().
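As a starting point, here is a base-R sketch of Garson's algorithm as it is commonly described for a single hidden layer and one output. caret's internal version may differ in details such as reporting percentages, so treat this as an assumption rather than a port. A torch version would apply the same arithmetic to the input-to-hidden and hidden-to-output weight tensors.

```r
# W: p x h matrix of input-to-hidden weights (inputs in rows)
# v: length-h vector of hidden-to-output weights
garson_importance <- function(W, v) {
  # Contribution of input i through hidden unit j: |w_ij| * |v_j|
  contrib <- sweep(abs(W), 2, abs(v), "*")
  # Share of each hidden unit's contribution coming from each input.
  # With one output, |v_j| cancels in this per-unit normalization.
  share <- sweep(contrib, 2, colSums(contrib), "/")
  # Sum over hidden units, then rescale so importances add to 1
  s <- rowSums(share)
  s / sum(s)
}

W <- matrix(c(1, 1, 1, 3), nrow = 2)
garson_importance(W, v = c(1, 1))  # 0.375 0.625
```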
Right now we convert the model coefficients to R arrays after the fit (so that they can be saved in RData objects).
For prediction, I think that we'll need to recreate the torch model object with the correct parameter values, then use it to make predictions.
Alternatively, we could save the torch model (a pointer to memory) and have a serialization function to get the parameters out of the session's memory location. We'd also need another function to create the torch model object in a new R session. This is probably faster in-session but is a general pain (unless there are some tools that I haven't seen).
The master branch of this repository will soon be renamed to main, as part of a coordinated change across several GitHub organizations (including, but not limited to: tidyverse, r-lib, tidymodels, and sol-eng). We anticipate this will happen by the end of September 2021.
That will be preceded by a release of the usethis package, which will gain some functionality around detecting and adapting to a renamed default branch. There will also be a blog post at the time of this master --> main change.
The purpose of this issue is to:
message id: euphoric_snowdog
Currently using torch::torch_manual_seed(sample.int(10^5, 1)), but getting different results with the same R seed. If both of these are re-run, the same results occur, so we need to find a way to set the torch seed reliably each time.
library(lantern)
suppressPackageStartupMessages(library(tidymodels))
data(ames)
ames$Sale_Price <- log10(ames$Sale_Price)
set.seed(1)
torch_mlp(x = as.matrix(ames[, c("Longitude", "Latitude")]),
y = ames$Sale_Price, penalty = 0.10, epochs = 10)
#> Multilayer perceptron via torch
#>
#> relu activation
#> 2 features, 3 hidden units, 11 model coefficients
#> weight decay: 0.1
#> final validation RMSE after 10 epochs: 83.2622
set.seed(1)
torch_mlp(x = as.matrix(ames[, c("Longitude", "Latitude")]),
y = ames$Sale_Price, penalty = 0.10, epochs = 10)
#> Multilayer perceptron via torch
#>
#> relu activation
#> 2 features, 3 hidden units, 11 model coefficients
#> weight decay: 0.1
#> final validation RMSE after 10 epochs: 5.77386
Created on 2020-10-14 by the reprex package (v0.3.0)