Comments (2)

mikemahoney218 commented on June 29, 2024

Howdy @jamesgrecian -- sorry to miss this originally, this repo doesn't see a ton of activity.

If I understand your question correctly: you can pass workflows and explainers to terra::predict() pretty easily. Depending on your workflow, you might need to write a wrapper function (for the fun argument) that extracts a vector of predictions from the workflow output, or use the index argument to subset the returned data frame:

## Packages and other setup:
set.seed(123)

library(sf)
#> Linking to GEOS 3.11.1, GDAL 3.6.2, PROJ 9.1.1; sf_use_s2() is TRUE
library(terra)
#> terra 1.7.29
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.3).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
library(tidymodels)

## Data prep:
# pak::pkg_install("Nowosad/spDataLarge")
data("lsl", "study_mask", package = "spDataLarge")
ta <- terra::rast(system.file("raster/ta.tif", package = "spDataLarge"))
lsl <- lsl |> 
  st_as_sf(coords = c("x", "y"), crs = "EPSG:32717")

## Fit some workflow 
glm_model <- logistic_reg() |> 
  set_engine("glm") |> 
  set_mode("classification")

glm_wflow <- workflow() |> 
  add_formula(lslpts ~ slope + cplan + cprof + elev + log10_carea) |> 
  add_model(glm_model) |> 
  fit(lsl)

glm_explainer <- explain_tidymodels(
  glm_wflow, 
  data = lsl, 
  y = as.logical(as.character(lsl$lslpts))
)
#> Preparation of a new explainer is initiated
#>   -> model label       :  workflow  (  default  )
#>   -> data              :  350  rows  7  cols 
#>   -> target variable   :  350  values 
#>   -> predict function  :  yhat.workflow  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package tidymodels , ver. 1.0.0 , task classification (  default  ) 
#>   -> model_info        :  Model info detected classification task but 'y' is a logical . Converted to numeric.  (  NOTE  )
#>   -> predicted values  :  numerical, min =  0.00233623 , mean =  0.5 , max =  0.9858769  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.9858769 , mean =  -7.765776e-10 , max =  0.9780077  
#>   A new explainer has been created!

## Predicting is easy via terra:
terra::predict(ta, glm_explainer) |> plot()

## Though you might need to wrap the prediction function 
## to guarantee you get a vector of predictions back:
try(terra::predict(ta, glm_wflow))
#> Error in out[[i]] <- data.frame(value = 1:length(out[[i]]), label = out[[i]]) : 
#>   more elements supplied than there are to replace
terra::predict(
  ta, 
  glm_wflow, 
  fun = \(model, object) predict(model, object)$.pred_class
) |> 
  plot()

## Alternatively, use the `index` argument to select which columns of the
## output should be rasterized. Note that the data type is lost here,
## because `glm_wflow` returns factors, which get converted to integers:
terra::predict(
  ta, 
  glm_wflow,
  index = 1
) |> plot()

Created on 2023-04-26 with reprex v2.0.2

Does that answer what you're asking?

jamesgrecian commented on June 29, 2024

Thanks @mikemahoney218, this is great!

I was confused before because, when using explain_tidymodels() and model_profile() to generate a partial dependence plot, the output differs from one call to the next, depending on the number of samples used.

I'm used to simply getting the mean prediction from a model prediction object. Based on the DALEX documentation it wasn't clear what to do in this case.
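
A minimal sketch of where that call-to-call variation comes from, assuming it is caused by the `N` argument of `DALEX::model_profile()`, which by default averages profiles over a random sample of 100 rows. The model here is a stand-in `glm` fitted to DALEX's bundled `titanic_imputed` data, not the landslide workflow from this thread, and the object names (`fit`, `explainer`, `pdp_a`, `pdp_b`, `pdp_full`) are illustrative only:

```r
library(DALEX)

# Stand-in model on DALEX's bundled data (an assumption for illustration;
# this is not the landslide workflow discussed in the thread).
fit <- glm(survived ~ age + fare, data = titanic_imputed, family = binomial())

explainer <- explain(
  fit,
  data = titanic_imputed[, c("age", "fare")],
  y = titanic_imputed$survived,
  verbose = FALSE
)

# Profiles are averaged over a random sample of N rows, so the curve can
# change between calls. Fixing the seed makes the sample reproducible.
set.seed(123)
pdp_a <- model_profile(explainer, variables = "age", N = 100)
set.seed(123)
pdp_b <- model_profile(explainer, variables = "age", N = 100)

# N = NULL uses every row of the explainer's data, so nothing is sampled.
pdp_full <- model_profile(explainer, variables = "age", N = NULL)

# Aggregated profile: one mean prediction (`_yhat_`) per grid value (`_x_`).
head(pdp_full$agr_profiles[, c("_vname_", "_x_", "_yhat_")])
```

The `agr_profiles` element is probably the closest match to "the mean prediction" described above: each row is the average model prediction with the chosen variable held at one grid value.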

Thanks again
