Comments (5)

lewi0332 commented on June 27, 2024

Hmm.

Sorry I didn't notice this before, but there is this in delayed_saturated_mmm.py, at line 363:

    def _serializable_model_config(self) -> Dict[str, Any]:
        serializable_config = self.model_config.copy()
        if type(serializable_config["beta_channel"]["sigma"]) == np.ndarray:
            serializable_config["beta_channel"]["sigma"] = serializable_config[
                "beta_channel"
            ]["sigma"].tolist()
        return serializable_config

It looks like there is a step that converts arrays to lists, but at the moment it only handles beta_channel["sigma"].
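A minimal standalone sketch of why only beta_channel gets through: the function below mirrors the quoted method as a plain function (the name serializable_model_config_original and the two-parameter config are illustrative, not pymc-marketing code). The array under beta_channel is converted, the one under alpha is not, so json.dumps raises the same TypeError.

```python
import json
from typing import Any, Dict, Optional

import numpy as np


def serializable_model_config_original(model_config: Dict[str, Any]) -> Dict[str, Any]:
    # Illustrative mirror of the quoted method: only beta_channel["sigma"] is converted
    serializable_config = model_config.copy()
    if isinstance(serializable_config["beta_channel"]["sigma"], np.ndarray):
        serializable_config["beta_channel"]["sigma"] = serializable_config["beta_channel"]["sigma"].tolist()
    return serializable_config


# Hypothetical config with per-channel arrays in two different entries
config = {
    "beta_channel": {"sigma": np.array([0.45, 0.25]), "dims": ("channel",)},
    "alpha": {"alpha": np.array([3, 3]), "beta": 3, "dims": ("channel",)},
}

error_message: Optional[str] = None
try:
    json.dumps(serializable_model_config_original(config))
except TypeError as err:
    # Raised for config["alpha"]["alpha"], which was never converted
    error_message = str(err)

print(error_message)  # Object of type ndarray is not JSON serializable
```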

I just added the other parameters and it's working for me now: I can save and load, and the model IDs match.

I'm a newbie, but this is what I did, and it works:

    @property
    def _serializable_model_config(self) -> Dict[str, Any]:
        # Note: .copy() is shallow, so the nested dicts are shared with self.model_config
        serializable_config = self.model_config.copy()
        for key in serializable_config:
            if isinstance(serializable_config[key], dict):
                for sub_key in serializable_config[key]:
                    if isinstance(serializable_config[key][sub_key], np.ndarray):
                        # Leave "dims" entries untouched
                        if sub_key == "dims":
                            pass
                        # Convert all other numpy arrays to lists
                        else:
                            serializable_config[key][sub_key] = serializable_config[key][sub_key].tolist()
        return serializable_config
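The same idea can also be written as a small recursive helper. This is a sketch under my own assumptions, not the fix that went into pymc-marketing, and to_json_safe is a hypothetical name. Unlike the loop above, it returns a new structure instead of editing a shallow copy in place, and it also handles numpy scalars and arrays nested more than one level deep.

```python
import json
from typing import Any

import numpy as np


def to_json_safe(value: Any) -> Any:
    """Return a JSON-friendly copy of value; the input itself is never modified."""
    if isinstance(value, dict):
        return {key: to_json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        # Tuples such as dims=("channel",) become lists, as json.dumps would emit anyway
        return [to_json_safe(item) for item in value]
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        # numpy scalars such as np.int64 -> built-in int/float/bool
        return value.item()
    return value


# Illustrative config shaped like default_model_config, with some array-valued priors
model_config = {
    "intercept": {"mu": np.int64(0), "sigma": 2},
    "beta_channel": {"sigma": np.array([0.45, 0.25]), "dims": ("channel",)},
    "alpha": {"alpha": np.array([3, 3]), "beta": np.array([3.55, 2.87]), "dims": ("channel",)},
    "gamma_fourier": {"mu": 0, "b": 1, "dims": "fourier_mode"},
}

serialized = json.dumps(to_json_safe(model_config))
```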

from pymc-marketing.

michaelraczycki commented on June 27, 2024

This needs to be patched. I wasn't sure whether, in the current model definition, priors with lists of arguments would apply to all variables, so I only added conversions for the ones that already accepted list values. It should be addressed in the next minor patch; I can probably get the PR in this week.

ricardoV94 commented on June 27, 2024

If you have any parameters defined as lists you should wrap them in numpy arrays.

Otherwise we may need a reproducible example to figure it out :)

lewi0332 commented on June 27, 2024

Thanks @ricardoV94.

You mean the parameters in model_config? Good catch. If I run the model without any changes to model_config, I can save and then load it without issue.

However, I just tested using np.array() values in model_config, and calling the fit() method raised an error saying that np.ndarray is not JSON serializable.

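import numpy as np
from pymc_marketing.mmm import DelayedSaturatedMMM

# Defined elsewhere (not shown): hyperparams, decay_channels, control_variables,
# data, target, SAMPLING_STEPS, TUNNING_STEPS, rng

# Throwaway instance, only used to read the default model_config
dummy_model = DelayedSaturatedMMM(date_column = '', channel_columns= '', adstock_max_lag = 12)
model_config = dummy_model.default_model_config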

#Model config default from .default_model_config
"""
model_config = {'intercept': {'mu': 0, 'sigma': 2},
 'beta_channel': {'sigma': 2, 'dims': ('channel',)},
 'alpha': {'alpha': 1, 'beta': 3, 'dims': ('channel',)},
 'lam': {'alpha': 3, 'beta': 1, 'dims': ('channel',)},
 'sigma': {'sigma': 2},
 'gamma_control': {'mu': 0, 'sigma': 2, 'dims': ('control',)},
 'mu': {'dims': ('date',)},
 'likelihood': {'dims': ('date',)},
 'gamma_fourier': {'mu': 0, 'b': 1, 'dims': 'fourier_mode'}}
"""

# Set priors from the params dataframe (bad params, I know... just testing something)
model_config['beta_channel']['sigma'] = hyperparams[hyperparams['primary_variable'].isin(decay_channels)]['beta'].values
model_config['alpha']['alpha'] = np.array([3 for i in decay_channels])
model_config['alpha']['beta'] = (((1 / hyperparams[hyperparams['primary_variable'].isin(decay_channels)]['ads_alpha'].values) * 3) - 3)
model_config['lam']['alpha'] = np.array([3 for i in decay_channels])
model_config['lam']['beta'] = (((1 / hyperparams[hyperparams['primary_variable'].isin(decay_channels)]['sat_gamma'].values) * 3) - 3)

# model_config after setting priors:
"""
{'intercept': {'mu': 0, 'sigma': 2},
 'beta_channel': {'sigma': array([0.4533017 , 0.25488063, 0.14992924, 0.14492646, 0.07828438]),
  'dims': ('channel',)},
 'alpha': {'alpha': array([3, 3, 3, 3, 3]),
  'beta': array([3.55001301, 2.87092431, 2.83535104, 2.76894977, 2.91873807]),
  'dims': ('channel',)},
 'lam': {'alpha': array([3, 3, 3, 3, 3]),
  'beta': array([4.12231653, 5.02896872, 5.42851249, 5.54761511, 5.7041018 ]),
  'dims': ('channel',)},
 'sigma': {'sigma': 2},
 'gamma_control': {'mu': 0, 'sigma': 2, 'dims': ('control',)},
 'mu': {'dims': ('date',)},
 'likelihood': {'dims': ('date',)},
 'gamma_fourier': {'mu': 0, 'b': 1, 'dims': 'fourier_mode'}}
"""

mmm = DelayedSaturatedMMM(
    model_config=model_config,
    date_column='date',
    channel_columns=decay_channels,
    control_columns=control_variables,
    adstock_max_lag=12,
    yearly_seasonality=2,
)

# ---------------------------------------------------------------------
# Fit Model
# ---------------------------------------------------------------------

mmm.fit(
    X=data.drop(target, axis=1),
    y=data[target],
    draws=SAMPLING_STEPS,
    tune=TUNNING_STEPS,
    target_accept=0.95,
    chains=4,
    random_seed=rng
    )

Here's the error I see when I have np.ndarray values in the model_config dict:

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc_experimental/model_builder.py:341, in ModelBuilder.set_idata_attrs(self, idata)
    339 idata.attrs["version"] = self.version
    340 idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
--> 341 idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
    342 # Only classes with non-dataset parameters will implement save_input_params
    343 if hasattr(self, "_save_input_params"):
...
--> 180     raise TypeError(f'Object of type {o.__class__.__name__} '
    181                     f'is not JSON serializable')

TypeError: Object of type ndarray is not JSON serializable
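The last frame is json's fallback for objects it cannot encode. json.dumps accepts a default callable that is consulted for exactly those objects, which is an alternative to converting the config by hand. This is only a sketch of that standard-library hook, not what model_builder.py does, and numpy_default is a hypothetical name.

```python
import json
from typing import Any

import numpy as np


def numpy_default(obj: Any) -> Any:
    """Fallback for json.dumps: only called for objects it cannot encode natively."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # numpy scalars such as np.int64 -> built-in Python numbers
        return obj.item()
    # Anything else keeps json's usual error
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")


# Illustrative config with array-valued priors, like the one above
model_config = {
    "alpha": {"alpha": np.array([3, 3, 3]), "beta": np.array([3.55, 2.87, 2.84])},
    "sigma": {"sigma": 2},
}

encoded = json.dumps(model_config, default=numpy_default)
```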

Strangely, I can use an np.array for the 'beta_channel': {'sigma': array([])} value and mmm.fit() runs without an error, but if I use an array instead of a list in any of the other model_config values, I get the JSON error.

Even stranger: after I call the fit() method, beta_channel['sigma'] in the model_config dict has been converted to a list.
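That side effect is consistent with the quoted method starting from self.model_config.copy(), which is a shallow copy: the nested per-parameter dicts are shared with the original, so assigning the .tolist() result into them changes model_config too. A plain-Python illustration (not pymc-marketing code), with copy.deepcopy shown for contrast:

```python
import copy

import numpy as np

model_config = {"beta_channel": {"sigma": np.array([0.45, 0.25]), "dims": ("channel",)}}

# dict.copy() copies only the top level; the nested "beta_channel" dict is shared
shallow = model_config.copy()
shallow["beta_channel"]["sigma"] = shallow["beta_channel"]["sigma"].tolist()
# model_config["beta_channel"]["sigma"] is now a list as well

# copy.deepcopy duplicates the nested dicts too, so the original keeps its array
model_config_2 = {"beta_channel": {"sigma": np.array([0.45, 0.25]), "dims": ("channel",)}}
deep = copy.deepcopy(model_config_2)
deep["beta_channel"]["sigma"] = deep["beta_channel"]["sigma"].tolist()
# model_config_2["beta_channel"]["sigma"] is still an ndarray
```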

Thanks for the response. I'll keep trying a few things.

ricardoV94 commented on June 27, 2024

Yes, I think you can only use arrays for parameters? @michaelraczycki will perhaps spot the issue more quickly.
