Comments (10)

daandouwe commented on August 20, 2024

This is interesting, thank you. I believe we have never trained on just source tags, so that must be why we have never encountered this error before. Thank you for reporting it.

Let me set up a small experiment to try to reproduce this. I will report back once I understand how it could be fixed.

from openkiwi.

daandouwe commented on August 20, 2024

I just tried training BERT- and XLM-R-based models on GPU, predicting only source tags, and did not encounter this problem. I used

trainer:
    main_metric:
        - source_tags_MCC

and

system:
    model:
        outputs:
            word_level:
                target: false
                gaps: false
                source: true

Could you share (the relevant parts of) your config? For example, what is main_metric that you are using?

iamhere1 commented on August 20, 2024

Thank you for your help. I use source_tags_F1_MULT and source_tags_CORRECT as my main metrics, and sentence-level outputs are not predicted in my experiment. The relevant parts of my config are as follows:

outputs:
    ####################################################
    # Output options configure the downstream tasks the
    #  model will be trained on by adding specific layers
    #  responsible for transforming decoder features into
    #  predictions.
    word_level:
        target: false
        gaps: false
        source: true
        class_weights:
            target_tags:
                BAD: 3.0
            gap_tags:
                BAD: 5.0
            source_tags:
                BAD: 3.0
    sentence_level:
        hter: false
        use_distribution: false
        binary: false
    n_layers_output: 2
    sentence_loss_weight: 1

and

main_metric:
    - source_tags_F1_MULT
    - source_tags_CORRECT

iamhere1 commented on August 20, 2024

Hi @daandouwe,
I tried replacing the code in callbacks.py as follows, and it seems to work now:

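    # Original comparators: NumPy cannot operate on a tensor that lives on the GPU.
    # mode_dict = {
    #     'min': np.less,
    #     'max': np.greater,
    #     'auto': np.greater if 'acc' in self.monitor else np.less,
    # }
    # PyTorch comparators accept CUDA tensors directly.
    # Note: torch.lt / torch.gt expect a tensor as their first argument.
    mode_dict = {
        'min': torch.lt,
        'max': torch.gt,
        'auto': torch.gt if 'acc' in self.monitor else torch.lt,
    }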

daandouwe commented on August 20, 2024

Thanks!

It turns out the problem is that not all metric values have been moved to the CPU.

Inspecting the values in the metrics dictionary at https://github.com/Unbabel/OpenKiwi/blob/master/kiwi/training/callbacks.py#L94 shows the following:

{'F1_BAD': tensor(0.6000),
 'F1_OK': tensor(0.6092),
 'loss': tensor(205.8955, device='cuda:0'),
 'metrics': {'F1_BAD': tensor(0.6000),
             'F1_OK': tensor(0.6092),
             'source_tags_CORRECT': tensor(0.6047, device='cuda:0'),
             'source_tags_F1_MULT': 0.36551724137931035,
             'source_tags_F1_MULT+source_tags_CORRECT': tensor(0.9702, device='cuda:0'),
             'source_tags_MCC': tensor(0.2713, dtype=torch.float64)},
 'source_tags_CORRECT': tensor(0.6047, device='cuda:0'),
 'source_tags_F1_MULT': 0.36551724137931035,
 'source_tags_F1_MULT+source_tags_CORRECT': tensor(0.9702, device='cuda:0'),
 'source_tags_MCC': tensor(0.2713, dtype=torch.float64),
 'val_F1_BAD': tensor(0.4514),
 'val_F1_OK': tensor(0.6182),
 'val_loss': tensor(141.6612, device='cuda:0'),
 'val_loss_source_tags': tensor(141.6612, device='cuda:0'),
 'val_source_tags_CORRECT': tensor(0.5497, device='cuda:0'),
 'val_source_tags_F1_MULT': 0.27903935726135765,
 'val_source_tags_F1_MULT+source_tags_CORRECT': tensor(0.8288, device='cuda:0'),
 'val_source_tags_MCC': tensor(0.1847, dtype=torch.float64)}
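Note that the combined key appears to be the plain sum of its two parts: 0.3655 + 0.6047 ≈ 0.9702 for training and 0.2790 + 0.5497 ≈ 0.8288 for validation. If that is how the main metrics are combined, the sum inherits the cuda:0 device from source_tags_CORRECT. A minimal sketch of that assumed combination (combine_main_metrics is a hypothetical helper inferred from the numbers above, not an OpenKiwi function):

```python
from typing import Any, Dict, Sequence


def combine_main_metrics(
    metrics: Dict[str, Any], names: Sequence[str], prefix: str = ""
) -> Any:
    """Hypothetical helper: store the sum of several metrics under a '+'-joined key.

    With prefix="val_", components are looked up as "val_<name>" and the result is
    stored as "val_<name1>+<name2>", matching the key layout in the dump above.
    """
    key = prefix + "+".join(names)
    # If any summand is a CUDA tensor, the resulting sum is a CUDA tensor as well.
    metrics[key] = sum(metrics[prefix + name] for name in names)
    return metrics[key]


train = {"source_tags_F1_MULT": 0.36551724137931035, "source_tags_CORRECT": 0.6047}
print(combine_main_metrics(train, ["source_tags_F1_MULT", "source_tags_CORRECT"]))
```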

This also explains why source_tags_MCC works as the main metric but source_tags_CORRECT does not: the former is a CPU tensor, while the latter is on cuda:0. (I believe this went unnoticed because it only seems to affect the lesser-used metrics.)
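To spot such values without reading the whole dump, one could scan the dictionary for entries whose device is not the CPU. The sketch below is duck-typed (anything exposing device.type, as torch tensors do, is checked), so it does not import torch; find_off_cpu_metrics is a hypothetical name. On the dictionary above it should flag entries such as loss, source_tags_CORRECT and metrics/source_tags_CORRECT.

```python
from typing import Any, Dict, List


def find_off_cpu_metrics(metrics: Dict[str, Any], prefix: str = "") -> List[str]:
    """Return the names of metric values that do not live on the CPU.

    Hypothetical, duck-typed helper: any value exposing device.type (as torch
    tensors do) is checked; plain Python numbers count as CPU values.
    Nested dicts such as the 'metrics' entry are reported as 'metrics/<name>'.
    """
    off_cpu: List[str] = []
    for key, value in metrics.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into nested dictionaries and keep the parent key as a prefix.
            off_cpu.extend(find_off_cpu_metrics(value, prefix=f"{name}/"))
            continue
        device = getattr(value, "device", None)
        if device is not None and getattr(device, "type", "cpu") != "cpu":
            off_cpu.append(name)
    return off_cpu
```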

We will need to solve this by making sure that all metrics return torch tensors that have been moved to the CPU. Preferably, all values returned by the metrics should simply be Python floats.
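As a sketch of the "just Python floats" option, assuming every leaf value in the dictionary is numeric, the metrics could be normalized once before they reach any callback. to_python_floats is a hypothetical helper, not part of OpenKiwi, and it is duck-typed so the example does not depend on torch:

```python
from typing import Any, Dict


def to_python_floats(metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `metrics` in which every scalar is a plain Python float.

    Hypothetical helper. Assumes all leaf values are numeric: objects with an
    .item() method (single-element torch tensors on any device, NumPy scalars)
    are converted via .item(), nested dicts are handled recursively, and
    anything else is passed to float().
    """
    result: Dict[str, Any] = {}
    for key, value in metrics.items():
        if isinstance(value, dict):
            result[key] = to_python_floats(value)
        elif hasattr(value, "item"):
            # For a tensor, .item() copies the value to the host, so no explicit .cpu() is needed.
            result[key] = float(value.item())
        else:
            result[key] = float(value)
    return result
```

Because the callback would then only ever see floats, the NumPy comparators in callbacks.py could stay as they are.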

We will try to fix this in a PR.

iamhere1 commented on August 20, 2024

Yes, thank you for your help! Or would replacing the NumPy functions with PyTorch operations be another solution?

iamhere1 commented on August 20, 2024

Yeah, that's more reasonable. Thank you!

daandouwe commented on August 20, 2024

In reply to "maybe another solution?":

You could edit

        current = metrics.get(self.monitor)
        if self.monitor_op(current - self.min_delta, self.best):

to

        current = metrics.get(self.monitor)
        # Metric values may be tensors on the GPU; convert them to a plain Python float first.
        if not isinstance(current, float):
            current = current.cpu().item()
        if self.monitor_op(current - self.min_delta, self.best):

in https://github.com/Unbabel/OpenKiwi/blob/master/kiwi/training/callbacks.py#L96 (in your local kiwi path).

I tried this and it worked, but of course it's not as nice as dealing with the problem at the root ;). We'll keep you updated on that.
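The same idea can be pulled out of the callback into a small, testable function. This is a sketch only: as_float, make_monitor_op and is_improvement are hypothetical names, the comparators are Python's operator.lt / operator.gt instead of NumPy's, and applying min_delta in the direction of improvement is an assumption of this sketch rather than a description of the callback:

```python
import math
import operator
from typing import Any, Callable


def as_float(value: Any) -> float:
    """Coerce a metric value to a Python float.

    Objects with an .item() method (torch tensors on any device, NumPy scalars)
    are converted via .item(); plain Python numbers go through float().
    """
    return float(value.item()) if hasattr(value, "item") else float(value)


def make_monitor_op(mode: str, monitor: str) -> Callable[[float, float], bool]:
    """Same mode mapping as the callback's mode_dict, but with plain-Python operators."""
    mode_dict = {
        "min": operator.lt,
        "max": operator.gt,
        "auto": operator.gt if "acc" in monitor else operator.lt,
    }
    return mode_dict[mode]


def is_improvement(
    current: Any,
    best: float,
    mode: str = "auto",
    monitor: str = "",
    min_delta: float = 0.0,
) -> bool:
    """Return True if `current` beats `best` by more than `min_delta`."""
    monitor_op = make_monitor_op(mode, monitor)
    # Apply min_delta in the direction of improvement (assumption of this sketch).
    delta = min_delta if monitor_op is operator.gt else -min_delta
    return monitor_op(as_float(current) - delta, best)


best = -math.inf  # nothing seen yet when maximizing
for value in [0.80, 0.82, 0.81]:
    if is_improvement(value, best, mode="max", min_delta=0.01):
        best = value
print(best)  # prints 0.82
```

Converting once at the comparison boundary keeps the check independent of the device a metric was computed on, though fixing the metrics themselves remains the cleaner solution.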

iamhere1 commented on August 20, 2024

OK, thanks for your help, it's working now.

captainvera commented on August 20, 2024

Hey @iamhere1,

I'm glad it's working; closing this issue due to inactivity.
