Hi,
to define a custom training loop you can create an instance of the Trainer class defined at line 273. Inside your training loop you can then call the member function train_step, defined at line 593.
To use a custom loss function, you can take a look at how I defined some of them, for example MeanSquaredError (line 28) or CategoricalCrossentropy (line 80). There are more in the same file that you can look at.
The new loss function you define needs to have a call method that implements your loss and a residuals method, which is used inside the LM optimizer.
The residuals have to be defined so that:
loss = mean(residuals^2)
It does not need to be literally like that: you can implement a more stable and computationally efficient expression, as long as the final result is the one above.
Let me know if it helps.
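To make the loss = mean(residuals^2) requirement concrete, here is a framework-agnostic NumPy sketch. The function names are illustrative and are not the library's MeanSquaredError or CategoricalCrossentropy classes; the square-root construction for cross-entropy is just one valid choice (it works because the per-sample cross-entropy is non-negative), not necessarily how the repo does it.

```python
import numpy as np

def mse_loss(y_true, y_pred):
    # Plain mean squared error over all elements.
    return np.mean((y_pred - y_true) ** 2)

def mse_residuals(y_true, y_pred):
    # For MSE the residual is simply the prediction error,
    # so mean(residuals**2) reproduces the loss exactly.
    return y_pred - y_true

def crossentropy_per_sample(y_true, y_pred, eps=1e-7):
    # Categorical cross-entropy for one-hot y_true: one value per sample.
    y_pred = np.clip(y_pred, eps, 1.0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

def crossentropy_residuals(y_true, y_pred):
    # Cross-entropy is >= 0 per sample, so its square root is a valid residual:
    # mean(residuals**2) == mean(per-sample cross-entropy).
    return np.sqrt(crossentropy_per_sample(y_true, y_pred))
```

Any other loss that is a mean of non-negative per-sample terms can be handled the same way, although a residual that is linear in the prediction error (as for MSE) is usually better behaved than a square root near zero.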
from tf-levenberg-marquardt.
Since w1 is a fixed constant, you do not need to pass it as a parameter of the function; you can store it as a member variable of your custom loss class.
import tensorflow as tf

# mse and custom_loss are your own functions, assumed to be defined elsewhere.

class CustomLoss(tf.keras.losses.Loss):
    def __init__(self,
                 w1,
                 reduction=tf.keras.losses.Reduction.AUTO,
                 name='custom_loss'):
        super(CustomLoss, self).__init__(
            reduction=reduction,
            name=name)
        # w1 is a fixed constant, so it is stored on the loss object
        self.w1 = w1

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        loss1 = mse(y_pred, y_true)
        loss2 = custom_loss(y_pred, self.w1)
        loss_tol = loss1 + loss2
        return loss_tol

    def residuals(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        # You have to write the code here according to how your custom loss is defined,
        # so that: loss = mean(residuals^2)
        raise NotImplementedError
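One way to fill in residuals when the total loss is a sum of two mean-squared terms is to concatenate both residual vectors and rescale each part. The NumPy sketch below assumes, hypothetically, that the custom loss can be written as mean(physics_term(y_pred, w1)**2); physics_term is a placeholder, not a function from the library or from the original code.

```python
import numpy as np

def physics_term(y_pred, w1):
    # HYPOTHETICAL stand-in for the custom loss: we assume it can be
    # written as mean(physics_term(...)**2). Replace with the real expression.
    return w1 * y_pred - 1.0

def total_loss(y_true, y_pred, w1):
    loss1 = np.mean((y_pred - y_true) ** 2)          # MSE part
    loss2 = np.mean(physics_term(y_pred, w1) ** 2)   # custom part
    return loss1 + loss2

def total_residuals(y_true, y_pred, w1):
    a = (y_pred - y_true).ravel()          # residuals of the MSE part (n values)
    b = physics_term(y_pred, w1).ravel()   # residuals of the custom part (m values)
    n, m = a.size, b.size
    # Rescale each part so that the mean over the concatenated vector
    # equals mean(a**2) + mean(b**2).
    scale_a = np.sqrt((n + m) / n)
    scale_b = np.sqrt((n + m) / m)
    return np.concatenate([scale_a * a, scale_b * b])
```

The scaling follows from mean(r^2) = ((n+m)/n * sum(a^2) + (n+m)/m * sum(b^2)) / (n+m) = mean(a^2) + mean(b^2). In the TensorFlow class the same idea would use tf.reshape and tf.concat.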
Thank you for your patient guidance!! It is very helpful!
I will try to implement it.
Thank you for your answer!! It is very helpful to me. To be honest, I am a beginner in Python.
My situation is that I have written a custom training loop following the TensorFlow guide, and I want to customize the loss function: its inputs are not only Y_pred and Y_true, but also some fixed constants (which do not need to be differentiated) used to compute the corresponding loss.
My problem is that, in my opinion (maybe I am wrong):
- we cannot pass two or more loss functions (e.g. the MSE loss and a custom loss) to model.compile and/or model_wrapper.compile;
- according to your code, we can only pass train_dataset to model.fit and/or model_wrapper.fit, so how can we pass other parameters to it, such as the fixed constants?
In my original code, I pass X, Y_true and the fixed constants to the training loop, and two losses are computed: the MSE loss based on Y_true and Y_pred, and the custom loss based on Y_pred and the fixed constants. Then the gradient of the total loss is computed, and finally Adam is used to minimize it. My goal is to replace the Adam optimizer with your LM optimizer so that I can obtain a much better result (as you know, plain gradient descent is not guaranteed to reach the global optimum). The idea of my code is as follows:
import tensorflow as tf

# model, mse, custom_loss and optimizer (Adam) are assumed to be defined elsewhere.

def train_step(x, y, w1):
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss1 = mse(y_pred, y)
        loss2 = custom_loss(y_pred, w1)  # Note: w1 is a fixed constant
        loss_tol = loss1 + loss2
    gradients = tape.gradient(loss_tol, model.trainable_variables)
    # This is the step that needs to be replaced with the LM optimizer
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss_tol
So I really hope you can give me some suggestions on how to achieve this. Thank you again!!
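For reference, what the LM optimizer does in place of apply_gradients is a damped Gauss-Newton update on the residuals. The NumPy sketch below is the textbook Levenberg-Marquardt loop with a simple divide/multiply-by-10 damping schedule and an analytic Jacobian; it illustrates the technique and is not the tf-levenberg-marquardt implementation. The exponential curve fit is made-up example data.

```python
import numpy as np

def lm_fit(residual_fn, jacobian_fn, w, n_steps=50, damping=1e-2):
    """Minimal Levenberg-Marquardt loop minimising mean(residual_fn(w)**2)."""
    loss = np.mean(residual_fn(w) ** 2)
    for _ in range(n_steps):
        r = residual_fn(w)
        J = jacobian_fn(w)          # shape (num_residuals, num_params)
        JJ = J.T @ J                # Gauss-Newton approximation of the Hessian
        g = J.T @ r
        # Damped normal equations: (J^T J + damping * I) delta = J^T r
        delta = np.linalg.solve(JJ + damping * np.eye(JJ.shape[0]), g)
        w_new = w - delta
        new_loss = np.mean(residual_fn(w_new) ** 2)
        if new_loss < loss:
            # Step accepted: reduce damping (closer to Gauss-Newton).
            w, loss = w_new, new_loss
            damping = max(damping / 10.0, 1e-10)
        else:
            # Step rejected: increase damping (closer to gradient descent).
            damping *= 10.0
    return w, loss

# Example: fit y = a * exp(b * x) with true parameters a = 2, b = -1.5.
x = np.linspace(0.0, 1.0, 20)
y = 2.0 * np.exp(-1.5 * x)

def residual_fn(w):
    return w[0] * np.exp(w[1] * x) - y

def jacobian_fn(w):
    e = np.exp(w[1] * x)
    return np.stack([e, w[0] * x * e], axis=1)

w, loss = lm_fit(residual_fn, jacobian_fn, np.array([1.0, 0.0]))
```

The damping parameter is what lets LM interpolate between Gauss-Newton steps (small damping) and short gradient-descent-like steps (large damping).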
No problem. Let me know how it goes.
Hi, Fabio,
Following your suggestions, today I tried to implement it. I have figured out how to define the custom loss function and use the Trainer class, but I found a new problem. Yesterday, in my problem statement, I forgot that there is one more input to my custom loss function: the custom loss depends not only on Y_true, Y_pred and some fixed constants, but also on the input variable X_train.
So I think model.compile / model_wrapper.compile cannot achieve my goal, because that method seems to fix the loss before training, while I need X_train to compute the custom loss during training. Moreover, the loss used in compile inherits from tf.keras.losses.Loss, which only accepts Y_true and Y_pred.
What would you advise?
I hope to hear from you soon.
Thanks a lot!!
A simple workaround is to include X_train in your Y data, so that Y_true is something like tf.stack([X_train, Y_train], axis=-1) or similar. I do not know the dimensionality of your data.
BTW: I do not think you need a custom training loop. I think it is easier for you to just use the ModelWrapper.
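The packing workaround can be sketched in NumPy as follows. It assumes X_train and Y_train have the same shape (required for stacking along a new last axis); in Keras the packed array would be what you pass as y to fit. The names loss_with_inputs and w1 and the custom term are hypothetical placeholders.

```python
import numpy as np

# Hypothetical data: one input feature and one target per sample (same shape).
x_train = np.linspace(0.0, 1.0, 5).reshape(-1, 1)
y_train = 3.0 * x_train

# Pack the inputs next to the targets so that the loss can see both.
y_packed = np.stack([x_train, y_train], axis=-1)   # shape (5, 1, 2)

def loss_with_inputs(y_true_packed, y_pred, w1=0.5):
    x = y_true_packed[..., 0]    # recover X_train
    y = y_true_packed[..., 1]    # recover Y_train
    mse = np.mean((y_pred - y) ** 2)
    # HYPOTHETICAL custom term that uses the inputs and a fixed constant w1.
    custom = np.mean((y_pred - w1 * x) ** 2)
    return mse + custom
```

The model itself still only sees x_train as input; the copy of X inside y is used exclusively by the loss.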
Hmm... in my task, X_train and Y_train do not have the same dimensions. Also, I first need to normalize the inputs and labels, and then during training, namely when computing the custom loss, I need to de-normalize the inputs, labels and predicted labels. So the operations could be complex. Maybe I need to study the ModelWrapper carefully.
Anyway, thanks for your advice. I will try it, and I hope I can bring you good news.
If X and Y have different dimensions, then Y_train can be a list or a tuple of tensors (X, Y).
In order to use the ModelWrapper, I would consider placing all the extra operations that you need to do during training inside the CustomLoss or inside the Model itself.
That way you can use all the callbacks provided by Keras to save checkpoints, log metrics, etc.
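As an alternative to a tuple when X and Y have different widths, the sketch below concatenates them along the feature axis and splits them again inside the loss using a known x_dim. It also keeps the fixed normalization statistics and w1 as member variables and de-normalizes inside the loss, as discussed above. Everything here (class name, parameter names, the custom term) is a hypothetical NumPy illustration, not the library's API.

```python
import numpy as np

class PackedLoss:
    """Sketch of a loss that receives [X | Y] concatenated in y_true.

    x_dim, the normalization statistics and w1 are fixed constants stored
    as member variables (HYPOTHETICAL names chosen for illustration).
    """

    def __init__(self, x_dim, x_mean, x_std, y_mean, y_std, w1):
        self.x_dim = x_dim
        self.x_mean, self.x_std = x_mean, x_std
        self.y_mean, self.y_std = y_mean, y_std
        self.w1 = w1

    def split(self, y_true_packed):
        # The first x_dim columns are the (normalized) inputs, the rest are targets.
        return y_true_packed[:, :self.x_dim], y_true_packed[:, self.x_dim:]

    def __call__(self, y_true_packed, y_pred):
        x_n, y_n = self.split(y_true_packed)
        mse = np.mean((y_pred - y_n) ** 2)   # computed in normalized units
        # Undo the normalization before evaluating the custom term.
        x = x_n * self.x_std + self.x_mean
        y_hat = y_pred * self.y_std + self.y_mean
        # HYPOTHETICAL custom term: prediction should match w1 * sum of the inputs.
        custom = np.mean((y_hat - self.w1 * x.sum(axis=1, keepdims=True)) ** 2)
        return mse + custom
```

With inputs of shape (batch, 3) and targets of shape (batch, 1), np.concatenate([x_n, y_n], axis=1) produces the packed (batch, 4) array to use as y.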
Hi, Fabio, I am trying to implement the code. By the way, there are two questions I would like to discuss with you.
The first one: apart from including X_train in the Y data (as you mentioned above), how can I pass other variables to model.fit? It seems to accept only two arguments, X and Y.
The second one: I know the LM algorithm was used early on in MATLAB's Neural Network Toolbox, back when deep learning and deep neural networks had not yet been developed. Nowadays we usually adopt deep networks, and from Google searches I read that the LM algorithm is useful for shallow networks with few neurons but performs poorly for deep networks with many trainable parameters. I also see that your test example is a simple nonlinear curve fit. So, have you tested this algorithm on more complex tasks, such as image recognition or natural language processing?
I also found an issue about using the algorithm in PINNs, which is an interesting topic, but some physics or engineering problems may need a deep network. In such a situation, can the LM algorithm get better results than gradient descent?
Thanks.
- No, model.fit only accepts X and Y, but since your loss is custom you can include in Y all the data you want, even a dictionary, so that you can access all the data you need to compute the loss (that is the way to do it).
- The LM algorithm is mainly used to train shallow networks, as they have a small number of parameters.
  The problem with LM is that the computational complexity is N^3 with respect to the number of parameters or the batch size, depending on which one of the two is larger. You can find the details in the Memory, Speed and Convergence considerations section of the README of my repo.
  You can trade off the complexity of LM by reducing the batch size, but the more you reduce it, the closer you get to simple stochastic gradient descent, and hence the more you lose the advantages of LM.
  So even though LM has a better per-step convergence rate (and sometimes can also converge to lower training losses), the computational cost of each step for large models is the main reason gradient descent is still the standard for training large neural networks.
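Why the cost is tied to both the number of parameters and the batch size can be seen from the damped LM step itself: it can be computed either by solving a (num_params x num_params) system built from J^T J, or by solving a (num_residuals x num_residuals) system built from J J^T, where the number of residuals grows with the batch size. The two forms are algebraically identical, (J^T J + lambda I)^-1 J^T = J^T (J J^T + lambda I)^-1. The NumPy sketch below checks this on random data; it makes no claim about which form the library uses internally.

```python
import numpy as np

rng = np.random.default_rng(0)
num_residuals, num_params = 8, 50   # here there are more parameters than residuals
J = rng.standard_normal((num_residuals, num_params))
r = rng.standard_normal(num_residuals)
damping = 1e-2

# Parameter-space solve: a (num_params x num_params) linear system.
delta_params = np.linalg.solve(J.T @ J + damping * np.eye(num_params), J.T @ r)

# Residual-space solve: a (num_residuals x num_residuals) linear system.
delta_resid = J.T @ np.linalg.solve(J @ J.T + damping * np.eye(num_residuals), r)
```

Solving a dense n x n system costs on the order of n^3 operations, and forming J^T J or J J^T from the full Jacobian adds further cost, which is why both large models and large batches make each LM step expensive.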