
Comments (10)

jameswex commented on August 22, 2024

Sorry for the lack of updates. The issue is that the attention_module rendering logic (https://github.com/PAIR-code/lit/blob/main/lit_nlp/client/modules/attention_module.ts#L109) assumes that, because the font is fixed-width, every character takes up the same fixed width in pixels, and it places its lines based on that. But Chinese characters are rendered wider even in a fixed-width font, so the math for placing the X positions of the attention lines is wrong.

The correct number of attention lines is drawn, but they are squeezed into too small a space, the text gets cut off incorrectly as a result, and the tokens that are shown don't line up with the lines they are meant for.

We'll work on fixing this. In the meantime, you could try changing the width setting at the line referenced above (and rebuilding the client) to see if you can get the spacing to look right for your use case. But we'll fix it so it works correctly regardless of language.

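To make the mismatch concrete, here is a small illustration in Python (not LIT code; CHAR_WIDTH_PX is an assumed stand-in for the width constant in attention_module.ts). An x offset computed as character count times a fixed per-character width matches the rendered text for ASCII, but drifts for full-width CJK glyphs, which occupy roughly two monospace cells:

# Illustration only: why per-character x positions drift for CJK text.
# CHAR_WIDTH_PX is a hypothetical stand-in for the constant in attention_module.ts.
from unicodedata import east_asian_width

CHAR_WIDTH_PX = 7.0

def assumed_x(prefix: str) -> float:
  """X offset under the 'every char is one fixed-width cell' assumption."""
  return len(prefix) * CHAR_WIDTH_PX

def rendered_x(prefix: str) -> float:
  """Approximate rendered X offset: full-width CJK glyphs span ~2 cells."""
  cells = sum(2 if east_asian_width(c) in ("W", "F") else 1 for c in prefix)
  return cells * CHAR_WIDTH_PX

print(assumed_x("hello "), rendered_x("hello "))  # 42.0 42.0 -> lines line up
print(assumed_x("你好，"), rendered_x("你好，"))  # 21.0 42.0 -> lines drift left of the glyphs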

jameswex commented on August 22, 2024

To rebuild the client, see https://github.com/PAIR-code/lit/#download-and-installation, specifically the "yarn && yarn build" command.


bigprince97 commented on August 22, 2024

If the input is English, it displays correctly, but it can't display the full Chinese text.


jameswex commented on August 22, 2024

Can you provide the link to your code / model / dataset so we can reproduce, if possible?


bigprince97 commented on August 22, 2024
from typing import Dict, List, Tuple

import numpy as np
import torch
import transformers

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils


class MyBertMLM(lit_model.Model):
  """Masked LM wrapper for a PyTorch Chinese BERT, with attention outputs."""

  MASK_TOKEN = "[MASK]"

  def __init__(self, model_name="bert-base-chinese", top_k=10):
    super().__init__()
    self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    self.model = transformers.BertForMaskedLM.from_pretrained(
        model_name, output_hidden_states=True, output_attentions=True)
    self.top_k = top_k

  @property
  def num_layers(self):
    return self.model.config.num_hidden_layers

  @property
  def max_seq_length(self):
    return self.model.config.max_position_embeddings

  def _get_topk_tokens(self,
                       scores: np.ndarray) -> List[List[Tuple[str, float]]]:
    # Top-k predicted tokens for each position, sorted by descending score.
    index_array = np.argpartition(scores, -self.top_k, axis=1)[:, -self.top_k:]
    top_tokens = [
        self.tokenizer.convert_ids_to_tokens(idxs) for idxs in index_array
    ]
    top_scores = np.take_along_axis(scores, index_array, axis=1)
    return [
        sorted(list(zip(toks, scores)), key=lambda ab: -ab[1])
        for toks, scores in zip(top_tokens, top_scores)
    ]

  def _postprocess(self, output: Dict[str, np.ndarray]):
    # Slice off the [CLS] and [SEP] special tokens.
    slicer = slice(1, output.pop("ntok") - 1)
    output["tokens"] = self.tokenizer.convert_ids_to_tokens(
        output.pop("input_ids")[slicer])
    probas = output.pop("probas")
    for i in range(self.num_layers):
      output[f"layer_{i:d}_attention"] = (
          output[f"layer_{i:d}_attention"][:, slicer, slicer])
    # Only keep predictions at [MASK] positions.
    output["pred_tokens"] = self._get_topk_tokens(probas[slicer])
    for i, token in enumerate(output["tokens"]):
      if token != self.MASK_TOKEN:
        output["pred_tokens"][i] = []
    return output

  def max_minibatch_size(self, unused_config=None) -> int:
    return 8

  def predict_minibatch(self, inputs, config=None):
    tokenized_texts = [
        ex.get("tokens") or self.tokenizer.tokenize(ex["text"]) for ex in inputs
    ]
    encoded_input = self.tokenizer.batch_encode_plus(
        tokenized_texts,
        is_pretokenized=True,
        return_tensors="pt",
        add_special_tokens=True,
        max_length=self.max_seq_length,
        pad_to_max_length=True)
    # Trim padding down to the longest sequence in the batch.
    max_tokens = torch.max(
        torch.sum(encoded_input["attention_mask"], dim=1))
    encoded_input = {k: v[:, :max_tokens] for k, v in encoded_input.items()}
    # Pre-4.0 transformers API: the model returns a tuple of
    # (logits, hidden_states, attentions).
    logits, embs, attentions = self.model(**encoded_input)
    batched_outputs = {
        "probas": torch.softmax(logits, dim=-1).detach().numpy(),
        "input_ids": encoded_input["input_ids"].numpy(),
        "ntok": torch.sum(encoded_input["attention_mask"], dim=1).numpy(),
        "cls_emb": embs[-1][:, 0].detach().numpy(),  # last layer, first token
    }
    for i in range(len(attentions)):
      batched_outputs[f"layer_{i:d}_attention"] = attentions[i].detach().numpy()
    unbatched_outputs = utils.unbatch_preds(batched_outputs)
    return map(self._postprocess, unbatched_outputs)

  def input_spec(self):
    return {
        "text": lit_types.TextSegment(),
        "tokens": lit_types.Tokens(required=False),
    }

  def output_spec(self):
    spec = {
        "tokens": lit_types.Tokens(parent="text"),
        "pred_tokens": lit_types.TokenTopKPreds(align="tokens"),
        "cls_emb": lit_types.Embeddings(),
    }
    for i in range(self.num_layers):
      spec[f"layer_{i:d}_attention"] = lit_types.AttentionHeads(
          align=("tokens", "tokens"))
    return spec
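For anyone reproducing this, a minimal standalone check of the class above might look like this (a sketch: it assumes the same pre-4.0 transformers/torch versions the snippet was written against, and the masked sentence is just an illustrative example):

# Sketch: run the model outside the LIT server and print top-k predictions
# for the [MASK] position. Assumes the pre-4.0 transformers API used above.
model = MyBertMLM("bert-base-chinese", top_k=5)
preds = list(model.predict_minibatch([{"text": "今天天气很[MASK]。"}]))
for token, topk in zip(preds[0]["tokens"], preds[0]["pred_tokens"]):
  if topk:  # non-empty only at [MASK] positions
    print(token, topk)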

I changed pretrained_lms.py to use a PyTorch Chinese BERT model and added attention to the output spec.

if model_name.startswith("bert-"):
  models[model_name] = pretrained_lms.MyBertMLM(
    model_name_or_path, top_k=FLAGS.top_k)

This is how I use my model in pretrained_lm_demo.py.

[screenshots attached]

It displays well if the input is English, but not if the input is Chinese; the same happens with the gpt2 model.


bigprince97 commented on August 22, 2024

I also attempted to change attention_module.ts, but it didn't work.


jameswex commented on August 22, 2024

We will reproduce this locally and work on a fix. Thanks for discovering the issue!


2020zyc commented on August 22, 2024

Same problem here. Any news? @bigprince97 @jameswex Thanks.


2020zyc commented on August 22, 2024

To rebuild the client, see https://github.com/PAIR-code/lit/#download-and-installation, specifically the "yarn && yarn build" command.

I changed the width and rebuilt successfully. The attention graph changed with the different widths, but it still looks wrong.

Looking forward to your fix. Thanks.


pratikchhapolika commented on August 22, 2024

Has align=("tokens", "tokens") been changed to this: lit_types.AttentionHeads(align_in="tokens", align_out="tokens")?

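For reference, a sketch of the newer-style spec entry, assuming a LIT release in which AttentionHeads takes separate alignment fields (check the lit_types version you have installed; older releases use the align tuple shown earlier in this thread):

# Sketch: inside output_spec(), for a newer lit_nlp release where AttentionHeads
# is declared with separate align_in/align_out fields instead of an align tuple.
for i in range(self.num_layers):
  spec[f"layer_{i:d}_attention"] = lit_types.AttentionHeads(
      align_in="tokens", align_out="tokens")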
