Comments (2)

Tylman-M commented on May 23, 2024

Here is how I fixed it in my code. I suspect there is a cleaner way, but in short, three things needed to change:

  1. Change 'end_token_id' to 'stop_token_ids'
  2. Wrap the token id argument in a list
  3. Add '.to_tensor()' to the encoder input tokens when padding

(Note: I extended mine from Spanish to French, so my tokenizers are named "fr_..." instead of "spa_".)

# Imports assumed from the tutorial's setup (Keras 3 / KerasNLP); the rest of
# the globals (eng_tokenizer, fr_tokenizer, transformer, MAX_SEQUENCE_LENGTH)
# also come from the notebook.
from keras import ops
import keras_nlp as nlp  # aliased to match the nlp.samplers call below

def decode_sequences(input_sentences):
    batch_size = 1

    # Tokenize the encoder input.
    encoder_input_tokens = ops.convert_to_tensor(eng_tokenizer(input_sentences))
    if len(encoder_input_tokens[0]) < MAX_SEQUENCE_LENGTH:
        pads = ops.full((1, MAX_SEQUENCE_LENGTH - len(encoder_input_tokens[0])), 0)
        #encoder_input_tokens = ops.concatenate([encoder_input_tokens, pads], 1) # <-- Original
        encoder_input_tokens = ops.concatenate([encoder_input_tokens.to_tensor(), pads], 1) # <-- Add ".to_tensor()" here

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def next(prompt, cache, index):
        logits = transformer([encoder_input_tokens, prompt])[:, index - 1, :]
        # Ignore hidden states for now; only needed for contrastive search.
        hidden_states = None
        return logits, hidden_states, cache

    # Build a prompt of length 40 with a start token and padding tokens.
    length = 40
    start = ops.full((batch_size, 1), fr_tokenizer.token_to_id("[START]"))
    pad = ops.full((batch_size, length - 1), fr_tokenizer.token_to_id("[PAD]"))
    prompt = ops.concatenate((start, pad), axis=-1)

    generated_tokens = nlp.samplers.GreedySampler()(
        next,
        prompt,
        # end_token_id = fr_tokenizer.token_to_id("[END]") #<-- Original
        stop_token_ids=[fr_tokenizer.token_to_id("[END]")], #<-- Change argument name and wrap in list
        index=1,  # Start sampling after start token.
    )
    generated_sentences = fr_tokenizer.detokenize(generated_tokens)
    return generated_sentences
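
For a quick sanity check, this is how I call it, following the post-processing from the keras-io example (the sample sentences here are just illustrative):

test_eng_texts = ["I love you.", "Where is the station?"]
for sentence in test_eng_texts:
    translated = decode_sequences([sentence])
    # detokenize returns a tensor of strings, so pull out the single result.
    translated = translated.numpy()[0].decode("utf-8")
    # Strip the special tokens from the decoded output.
    translated = (
        translated.replace("[PAD]", "")
        .replace("[START]", "")
        .replace("[END]", "")
        .strip()
    )
    print(f"{sentence} -> {translated}")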


