Tsformer pretrain question about step [CLOSED]

Jimmy-7664 commented on May 26, 2024
Tsformer pretrain question

from step.

Comments (5)

zezhishao commented on May 26, 2024

Thanks for your question.
I have checked, and this is indeed a bug.
It does not affect the final result, though, because this metric is not used for anything.
Thanks again for reporting it. I'll fix it right away!

Jimmy-7664 commented on May 26, 2024

Thanks for your reply. So the correct test method is the one in the file "base_tsf_runner.py"?

    def test(self):
        """Evaluate the model.

        Args:
            train_epoch (int, optional): current epoch if in training process.
        """

        # test loop
        prediction = []
        real_value = []
        for _, data in enumerate(self.test_data_loader):
            forward_return = self.forward(data, epoch=None, iter_num=None, train=False)
            prediction.append(forward_return[0])        # preds = forward_return[0]
            real_value.append(forward_return[1])        # testy = forward_return[1]
        prediction = torch.cat(prediction, dim=0)
        real_value = torch.cat(real_value, dim=0)
        # re-scale data
        prediction = SCALER_REGISTRY.get(self.scaler["func"])(
            prediction, **self.scaler["args"])
        real_value = SCALER_REGISTRY.get(self.scaler["func"])(
            real_value, **self.scaler["args"])
        # summarize the results.
        # test performance of different horizon
        for i in self.evaluation_horizons:
            # For horizon i, only calculate the metrics **at that time** slice here.
            pred = prediction[:, i, :]
            real = real_value[:, i, :]
            # metrics
            metric_results = {}
            for metric_name, metric_func in self.metrics.items():
                metric_item = self.metric_forward(metric_func, [pred, real])
                metric_results[metric_name] = metric_item.item()
            log = "Evaluate best model on test data for horizon " + \
                "{:d}, Test MAE: {:.4f}, Test RMSE: {:.4f}, Test MAPE: {:.4f}"
            log = log.format(
                i+1, metric_results["MAE"], metric_results["RMSE"], metric_results["MAPE"])
            self.logger.info(log)
        # test performance overall
        for metric_name, metric_func in self.metrics.items():
            metric_item = self.metric_forward(metric_func, [prediction, real_value])
            self.update_epoch_meter("test_"+metric_name, metric_item.item())
            metric_results[metric_name] = metric_item.item()

Am I right?

zezhishao commented on May 26, 2024

No. The test function in base_tsf_runner is designed for the Time Series Forecasting (TSF) problem, so it is not compatible with the reconstruction task in the pre-training stage.
Instead, I think you can fix this by adding one level of indentation (a Tab) to lines 84~90, so that the re-scaling and the metric computation run inside the for loop, like:

    @torch.no_grad()
    @master_only
    def test(self):
        """Evaluate the model.

        Args:
            train_epoch (int, optional): current epoch if in training process.
        """

        for _, data in enumerate(self.test_data_loader):
            forward_return = self.forward(data=data, epoch=None, iter_num=None, train=False)
            # re-scale data
            prediction_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[0], **self.scaler["args"])
            real_value_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[1], **self.scaler["args"])
            # metrics
            for metric_name, metric_func in self.metrics.items():
                metric_item = metric_func(prediction_rescaled, real_value_rescaled, null_val=self.null_val)
                self.update_epoch_meter("test_"+metric_name, metric_item.item())
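
To illustrate why the indentation matters, here is a minimal sketch in plain Python. It assumes the bug was that the re-scaling and metric lines sat outside the loop, so only the last batch reached the meter. `AverageMeter`, `mae`, and both `evaluate_*` functions are hypothetical stand-ins, not part of the step codebase.

```python
class AverageMeter:
    """Hypothetical running average, standing in for update_epoch_meter."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    @property
    def avg(self):
        return self.total / self.count if self.count else 0.0


def mae(pred, real):
    """Mean absolute error over two equal-length lists."""
    return sum(abs(p - r) for p, r in zip(pred, real)) / len(pred)


def evaluate_last_batch_only(batches):
    """Assumed buggy layout: the metric update is outside the loop."""
    meter = AverageMeter()
    for pred, real in batches:
        forward_return = (pred, real)
    # Runs once, after the loop, so it only sees the final batch.
    meter.update(mae(forward_return[0], forward_return[1]))
    return meter.avg


def evaluate_every_batch(batches):
    """Fixed layout: the metric update is indented into the loop."""
    meter = AverageMeter()
    for pred, real in batches:
        meter.update(mae(pred, real))
    return meter.avg


batches = [
    ([1.0, 2.0], [1.0, 2.0]),  # per-batch MAE 0.0
    ([3.0, 4.0], [1.0, 2.0]),  # per-batch MAE 2.0
    ([5.0, 5.0], [1.0, 1.0]),  # per-batch MAE 4.0
]

print(evaluate_last_batch_only(batches))  # 4.0
print(evaluate_every_batch(batches))      # 2.0
```

With the update inside the loop, the meter averages the metric over all test batches instead of reporting the last batch alone.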

Jimmy-7664 commented on May 26, 2024

Thanks for your answer, I get the idea. :)

zezhishao commented on May 26, 2024

This bug is now fixed. Thanks again for your report!
