
Article Figure 9 (figstep, closed)

hbrachemi commented on July 29, 2024
Article Figure 9


Comments (2)

YichenBC commented on July 29, 2024

Thank you for your interest in our work.
To output the embeddings, simply pass output_hidden_states=True to the generate function:

result = model.generate(
    input_ids,
    images=images,
    do_sample=temperature > 0,
    temperature=temperature,
    max_new_tokens=max_new_tokens,
    streamer=streamer,
    use_cache=True,
    top_p=top_p,
    stopping_criteria=[stopping_criteria],
    output_attentions=True,
    output_hidden_states=True,
    output_scores=True,
    return_dict_in_generate=True,
)

You can then use the following code to extract the embedding of the last token in the last layer:

# Presumably inside the query helper, which returns the hidden states:
hidden_states = result.hidden_states

# For each query:
hidden_states = query(text_prompt=query_text, image_prompt=query_image)
embedding.append(hidden_states[0][-1][:, -1, :].cpu().numpy()[0])

This gives you a semantic representation of one query, which you append to the list of embeddings.
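
In case the indexing is unclear: with return_dict_in_generate=True and output_hidden_states=True, Hugging Face generate returns hidden_states as a tuple with one entry per generated token, and each entry is a tuple with one tensor per layer of shape (batch, sequence_length, hidden_size). Assuming the model's generate follows that layout, the sketch below mimics the nesting with NumPy arrays to show which vector the expression selects. The sizes and the fake_hidden_states / fake_step names are made up for illustration only.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
num_steps, num_layers, batch, prompt_len, hidden = 3, 4, 1, 5, 8
rng = np.random.default_rng(0)

def fake_step(seq_len):
    """One generation step: a tuple with one (batch, seq_len, hidden) array per layer."""
    return tuple(rng.normal(size=(batch, seq_len, hidden)) for _ in range(num_layers))

# Step 0 processes the whole prompt; with use_cache=True later steps see one token each.
fake_hidden_states = (fake_step(prompt_len),) + tuple(
    fake_step(1) for _ in range(num_steps - 1)
)

# [0]  -> first generation step (covers the full prompt)
# [-1] -> last layer
last_layer_first_step = fake_hidden_states[0][-1]  # shape (batch, prompt_len, hidden)

# [:, -1, :] -> last prompt token; the trailing [0] drops the batch dimension
embedding = last_layer_first_step[:, -1, :][0]     # shape (hidden,)
```

So each appended row is the last-layer state of the final prompt token, taken at the first generation step, rather than anything computed from the generated answer.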
After generating several lists of embeddings, you can use t-SNE to produce such figures.
In our case, we saved these embeddings to CSV files and drew the figures from them. The code is as follows:

import numpy as np
import pandas as pd

# Each CSV holds one embedding per row, with no header row.
df = pd.read_csv("emb/harmful_150_embedding.csv", header=None)
harmful_150_embedding = df.to_numpy()

df = pd.read_csv("emb/harmless_150_embedding.csv", header=None)
harmless_150_embedding = df.to_numpy()

df = pd.read_csv("emb/figstep_harmful_150_embedding.csv", header=None)
figstep_harmful_150_embedding = df.to_numpy()

df = pd.read_csv("emb/figstep_harmless_150_embedding.csv", header=None)
figstep_harmless_150_embedding = df.to_numpy()

df = pd.read_csv("emb/harmless_mode2_150_embedding.csv", header=None)
harmless_mode2_embedding = df.to_numpy()

df = pd.read_csv("emb/harmful_mode2_150_embedding.csv", header=None)
harmful_mode2_embedding = df.to_numpy()

# Stack the six groups in a fixed order; the plotting labels follow the same order.
combined_data = np.vstack((
    harmful_150_embedding,
    harmless_150_embedding,
    figstep_harmful_150_embedding,
    figstep_harmless_150_embedding,
    harmful_mode2_embedding,
    harmless_mode2_embedding,
))

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

fig, ax = plt.subplots(figsize=(8, 5), dpi=300)
color = ["#39c5bb", "tomato", "#EE82EE", "orange", "blue", "skyblue"]
marker = ["o", "s", "*", "x", "+", "^"]
# Raw strings keep the LaTeX backslashes intact.
label = [
    r"prohibited $\mathbb{Q}^{va}$",
    r"benign $\mathbb{Q}^{va}$",
    r"prohibited $\mathsf{FigStep}$",
    r"benign $\mathsf{FigStep}$",
    r"prohibited $\mathbb{Q}_2'$",
    r"benign $\mathbb{Q}_2'$",
]

# Project all embeddings to 2-D in a single t-SNE fit so the groups share one layout.
vis = TSNE(n_components=2).fit_transform(np.array(combined_data))

# Each group occupies 150 consecutive rows; the order below sets the drawing and legend order.
for i in [1, 0, 5, 4, 3, 2]:
    rows = slice(150 * i, 150 * (i + 1))
    ax.scatter(vis[rows, 0], vis[rows, 1], c=color[i], marker=marker[i], label=label[i])
plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)
ax.legend(loc='upper left', fontsize=14)
ax.axis('off')
plt.savefig("emb-llava-1.png", transparent=True)
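
The script above reads CSV files that have no header row and one embedding per row. The thread does not show how those files were written, so the following is only a sketch of one compatible way to do it, assuming the collected embeddings are a list of equal-length 1-D NumPy arrays. The random data, the hidden size of 16, and the file name are placeholders, not the authors' values.

```python
import os
import tempfile

import numpy as np

rng = np.random.default_rng(0)

# Stand-in for 150 collected embeddings (placeholder data, arbitrary hidden size of 16).
embedding = [rng.normal(size=16) for _ in range(150)]

# Hypothetical output location; the thread's files live under "emb/".
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "example_150_embedding.csv")

# One row per query, comma-separated, no header row.
np.savetxt(path, np.stack(embedding), delimiter=",")

# Read the file back to check the round trip.
reloaded = np.loadtxt(path, delimiter=",")
```

A file written this way has the header-less, one-row-per-embedding layout that pd.read_csv(..., header=None) expects, so it can be loaded the same way as the files in the script above.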

I hope this will help you.


hbrachemi commented on July 29, 2024

Hi again,
thanks a lot for the answer, this seems to work for me.
