
franneck94 / tensorcross

14.0 5.0 3.0 1.54 MB

Cross Validation, Grid Search and Random Search for TensorFlow 2 Datasets

License: MIT License

Python 99.46% Makefile 0.54%
cross-validation grid-search tensorflow-datasets tensorflow2 random-search validation tensorflow gridsearchcv dataset python

tensorcross's People

Contributors

franneck94

Forkers

satellitesky

tensorcross's Issues

[BUG] Correct the train-validation split for the cross-validation

System information

OS Platform and Distribution: All
TensorFlow version: 2.3
Python version: 3.8

Source code / logs

Correct the train-validation split in the fit method of BaseSearchCV (see here).
The validation set should be slid over the passed-in dataset so that every data point is in the validation set exactly once.

Be sure to write tests for this new feature.
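A minimal sketch of the requested sliding split, working on sample indices only so that it runs without TensorFlow. The function names are hypothetical and not part of tensorcross. For an ordered, finite tf.data.Dataset, each (start, end) pair would map to dataset.skip(start).take(end - start) for validation and dataset.take(start).concatenate(dataset.skip(end)) for training.

```python
from typing import Iterator, List, Tuple


def sliding_fold_bounds(num_samples: int, num_folds: int) -> List[Tuple[int, int]]:
    """Return the (start, end) bounds of the validation window for each fold."""
    if not 1 < num_folds <= num_samples:
        raise ValueError("need 1 < num_folds <= num_samples")
    base, remainder = divmod(num_samples, num_folds)
    bounds = []
    start = 0
    for fold in range(num_folds):
        # The first `remainder` folds get one extra sample so no sample is dropped.
        size = base + (1 if fold < remainder else 0)
        bounds.append((start, start + size))
        start += size
    return bounds


def split_indices(num_samples: int, num_folds: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, val_indices) with the validation window slid over the data."""
    for start, end in sliding_fold_bounds(num_samples, num_folds):
        # tf.data equivalent (assumption: ordered, finite dataset):
        #   val_ds = dataset.skip(start).take(end - start)
        #   train_ds = dataset.take(start).concatenate(dataset.skip(end))
        val = list(range(start, end))
        train = list(range(0, start)) + list(range(end, num_samples))
        yield train, val


print(sliding_fold_bounds(10, 3))  # [(0, 4), (4, 7), (7, 10)]
```

A test for the fix could assert that the concatenated validation indices cover every sample exactly once and never overlap with the training indices of the same fold.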

[REQ] Update examples in README.md

System information

OS Platform and Distribution: All
TensorFlow version: 2.3
Python version: 3.8

Source code / logs

The examples shown in the README.md can be improved.
First, an example of a cross-validated search should be added.
My recommendation is:

  • One example for GridSearch
  • One example for GridSearchCV
  • Remove the RandomSearch example from the README.md

[REQ] Remove the Issues and Pull Request sections from README.md

System information

OS Platform and Distribution: All
TensorFlow version: 2.3
Python version: 3.8

Source code / logs

Remove the references to .github/PULL_REQUEST_TEMPLATE/ and .github/ISSUE_TEMPLATE/ from the README.md.
These can't be parsed in the documentation repository.

Mean is calculated over wrong axis [BUG]

In the _run_search() function in search_cv.py, the code calculates mean_val_scores as follows:

mean_val_scores = np.mean(self.results_["val_scores"], axis=0)

This averages over the different model parameter combinations for each validation fold, so the result has one value per fold (column) instead of one per combination, and the selection then picks the column with the best average score across all combinations rather than a parameter combination. The mean should instead be calculated over the scores within each combination.

To fix this bug, calculate mean_val_scores as follows:

mean_val_scores = np.mean(self.results_["val_scores"], axis=1)
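A small NumPy illustration of the difference between the two axes. It assumes, as the report implies, that results_["val_scores"] has one row per parameter combination and one column per validation fold; the numbers are made up for illustration.

```python
import numpy as np

# Assumed layout: one row per parameter combination, one column per validation fold.
val_scores = np.array([
    [0.70, 0.80, 0.90],  # combination 0
    [0.85, 0.75, 0.95],  # combination 1
])

wrong = np.mean(val_scores, axis=0)  # averages over combinations -> one value per fold
right = np.mean(val_scores, axis=1)  # averages over folds -> one value per combination

print(wrong.shape, np.argmax(wrong))  # (3,) 2  -> a fold index, not a combination
print(right.shape, np.argmax(right))  # (2,) 1  -> combination 1 has the best mean
```

With axis=0 the argmax can even return an index (here 2) that does not correspond to any parameter combination.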

Does not work with tensorflow-macos [BUG]

pip install does not work with tensorflow-macos, which is required on M1/M2 machines:
[Screenshot of the error, taken 2023-03-11 at 18:34:17]

Installed packages:

  • tensorflow-macos 2.9.0
  • tensorflow-metal 0.5.1

[BUG] TensorFlow warning when creating multiple models from the same function

System information

OS Platform and Distribution: All
TensorFlow version: 2.3
Python version: 3.8

Source code / logs

See example: here
The code will raise a TensorFlow warning:

WARNING:tensorflow:5 out of the last 15 calls to <function Model.make_train_function.<locals>.train_function at 0x0000024F785DBC10> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.

We need to find a way to suppress only this warning.
If there is no such option, we could disable warnings before running the search and re-enable them afterwards.
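One possible way to silence only this warning is a standard logging filter that is attached for the duration of the search and removed afterwards. This is a sketch under the assumption that TensorFlow emits the message through the "tensorflow" logger (the one returned by tf.get_logger()); the helper names are hypothetical, not tensorcross API.

```python
import contextlib
import logging


class _RetracingFilter(logging.Filter):
    """Drop only the tf.function retracing warning; let every other record through."""

    def filter(self, record: logging.LogRecord) -> bool:
        return "triggered tf.function retracing" not in record.getMessage()


@contextlib.contextmanager
def suppress_retracing_warnings(logger_name: str = "tensorflow"):
    # Assumption: TensorFlow logs this warning via logging.getLogger("tensorflow"),
    # which is what tf.get_logger() returns.
    logger = logging.getLogger(logger_name)
    retracing_filter = _RetracingFilter()
    logger.addFilter(retracing_filter)
    try:
        yield
    finally:
        # Re-enable the warning once the search has finished, even on errors.
        logger.removeFilter(retracing_filter)


# Hypothetical usage around a search object:
# with suppress_retracing_warnings():
#     search.fit(train_dataset, ...)
```

Filtering by message text keeps all other TensorFlow warnings visible, which is less invasive than raising the logger level to ERROR for the whole search.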

[REQ] Run TensorBoard callback for each parameter combination

System information

OS Platform and Distribution: All
TensorFlow version: 2.3
Python version: 3.8

Source code / logs

If the user passes a TensorBoard callback to the .fit method of the grid/random search, the callback stores the logs of every parameter combination in the same directory.
It would be better to create a sub-directory for each parameter combination.
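A minimal sketch of deriving one log sub-directory per parameter combination. The helper names and the "key=value" naming scheme are assumptions for illustration, not the scheme tensorcross actually uses; only tf.keras.callbacks.TensorBoard(log_dir=...) in the comment is real Keras API.

```python
import itertools
import os


def iter_param_combinations(param_grid):
    """Yield every combination of a parameter grid as a dict (keys sorted)."""
    keys = sorted(param_grid)
    for values in itertools.product(*(param_grid[key] for key in keys)):
        yield dict(zip(keys, values))


def combination_log_dir(base_dir, params):
    """Build a sub-directory such as 'logs/learning_rate=0.001_optimizer=adam'."""
    name = "_".join(f"{key}={value}" for key, value in sorted(params.items()))
    return os.path.join(base_dir, name)


# Hypothetical use inside a search loop:
# for params in iter_param_combinations(param_grid):
#     tb = tf.keras.callbacks.TensorBoard(log_dir=combination_log_dir("logs", params))
#     model.fit(train_ds, validation_data=val_ds, callbacks=[tb])
```

Sorting the keys makes the directory name independent of the order in which the grid was defined; values containing path separators would need extra escaping.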

Program returns opposite of 'best model' [BUG]

In the _run_search() function in search_cv.py, the code treats the model with the largest mean score as the best one:

mean_val_scores = np.mean(self.results_["val_scores"], axis=0)
best_run_idx = np.argmax(mean_val_scores)

Essentially, this returns the wrong model if some kind of loss is used to compute val_scores, because for a loss, smaller is better. This wouldn't be a problem (it would be a feature) if the documentation did not state that

The grid search is evaluated by:

  • The validation loss value, if no metrics are passed to model.compile()
  • The validation score of the last defined metric in model.compile()

This implies that you can use either a score or a loss function and still get the best model, which isn't the case.

The best way to fix this would be to add a 'mode' argument to the function/method that determines whether the score should be maximized or minimized.
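A sketch of the proposed 'mode' argument. The function name and signature are hypothetical; the per-combination mean uses axis=1, as proposed in the axis bug report, and assumes one row per parameter combination.

```python
import numpy as np


def select_best_run(val_scores, mode="max"):
    """Return the index of the best parameter combination.

    val_scores: array of shape (n_combinations, n_folds) (assumed layout).
    mode: "max" if higher is better (e.g. accuracy),
          "min" if lower is better (e.g. a loss).
    """
    mean_val_scores = np.mean(val_scores, axis=1)  # one mean per combination
    if mode == "max":
        return int(np.argmax(mean_val_scores))
    if mode == "min":
        return int(np.argmin(mean_val_scores))
    raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")


val_losses = np.array([[0.5, 0.4], [0.2, 0.3]])  # made-up loss values
print(select_best_run(val_losses, mode="min"))  # 1
```

Keras callbacks such as EarlyStopping and ModelCheckpoint expose a similar mode argument ("min", "max", "auto"), so the same naming would feel familiar to users.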

[REQ] Example for TensorBoard usage

System information

OS Platform and Distribution: All
TensorFlow version: 2.0+
Python version: 3.7+

Source code / logs

Release 0.2.0 added the creation of sub-folders for a TensorBoard callback.
Add an example of GridSearchCV in combination with a TensorBoard callback.
Related issue: #31
