Comments (8)
Thank you, that's very kind of you!
The code will come soon. I will clean up our code after I finish some other jobs.
from audiossl.
@lmaxwell Thanks for adding ATST Frame codes.
I'm trying to follow: https://github.com/Audio-WestlakeU/audiossl/blob/main/audiossl/methods/atstframe/README.md
However, I suspect the checkpoint file contains model weights only, because the following doesn't work:
model = load_model("audiotransformer_base_mAP_4999.pt")
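The question above comes down to what a checkpoint file actually contains. This is a minimal sketch of how to check that. It assumes the file is loaded with `torch.load(path, map_location="cpu")`. It also assumes that a full PyTorch Lightning checkpoint is a dict with a `"state_dict"` entry, while a weights-only file is a flat mapping from parameter names to tensors. `describe_checkpoint` is a hypothetical helper and is not part of audiossl.

```python
from collections.abc import Mapping


def describe_checkpoint(ckpt) -> str:
    """Classify an object returned by torch.load(path, map_location="cpu").

    Hypothetical helper (not part of audiossl):
    - "lightning":    a dict with a "state_dict" entry, as PyTorch Lightning saves
    - "weights-only": a flat mapping whose values all look like tensors (have .shape)
    - "unknown":      anything else
    """
    if not isinstance(ckpt, Mapping):
        return "unknown"
    if "state_dict" in ckpt:
        return "lightning"
    if ckpt and all(hasattr(v, "shape") for v in ckpt.values()):
        return "weights-only"
    return "unknown"


# Usage (requires torch and a local checkpoint file):
# import torch
# print(describe_checkpoint(torch.load("atstframe_base.ckpt", map_location="cpu")))

print(describe_checkpoint({"state_dict": {}, "epoch": 3}))  # lightning
```

If the result is `"weights-only"`, the file probably lacks the extra training metadata that a full checkpoint carries. It may therefore not be the checkpoint that `load_model` expects.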
@daisukelab Hi, sorry for the late response. I have rechecked the model loading, and it works fine. I downloaded the checkpoint from Google Drive and ran the following code. What problem did you encounter?
from audiossl.methods.atstframe.embedding import load_model,get_scene_embedding,get_timestamp_embedding
import torch
model = load_model("./atstframe_base.ckpt")
audio = torch.randn(1,20000) # Input audio can be of shape [1,N] or [B,1,N]
"""
Extract scene (clip-level) embedding from an audio clip
=======================================
args:
audio: torch.tensor in the shape of [1,N] or [B,1,N]
model: the pretrained encoder returned by load_model
return:
emb: returned embedding in the shape of [1,N_BLOCKS*emb_size] or [B,1,N_BLOCKS*emb_size], where emb_size is 768 for the base model and 384 for the small model.
"""
emb_scene = get_scene_embedding(audio,model)
"""
Extract frame-level embeddings from an audio clip
==================================================
args:
audio: torch.tensor in the shape of [1,N] or [B,1,N]
model: the pretrained encoder returned by load_model
return:
emb: returned embedding in the shape of [1,T,N_BLOCKS*emb_size] or [B,1,T,N_BLOCKS*emb_size], where emb_size is 768 for the base model and 384 for the small model.
timestamps: timestamps in milliseconds
"""
emb_timestamp,t = get_timestamp_embedding(audio,model)
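The docstrings above give the embedding width as N_BLOCKS*emb_size. This small worked example solves for N_BLOCKS. It uses the width of 9216 reported later in this thread for the base model (emb_size = 768). The value for the small model assumes it uses the same number of blocks, which I have not verified. `infer_n_blocks` is a hypothetical helper.

```python
def infer_n_blocks(feature_dim: int, emb_size: int) -> int:
    """Infer N_BLOCKS from an observed width, assuming width = N_BLOCKS * emb_size."""
    n_blocks, remainder = divmod(feature_dim, emb_size)
    if remainder:
        raise ValueError(f"{feature_dim} is not a multiple of emb_size={emb_size}")
    return n_blocks


EMB_SIZE = {"base": 768, "small": 384}  # per the docstrings above

n_blocks = infer_n_blocks(9216, EMB_SIZE["base"])  # 9216: base-model width reported in this thread
print(n_blocks)  # 12
# Assumption: the small model concatenates the same number of blocks (not verified).
print(n_blocks * EMB_SIZE["small"])  # 4608
```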
Hi, @daisukelab. Thank you for correcting me.
- get_scene_embedding() returns [B, D], where B and D are the batch size and feature dimension, whereas according to the README it should be [B,1,N_BLOCKS*emb_size].
audio = torch.randn(16, 1, 20000)
emb_scene = get_scene_embedding(audio,model)
emb_scene.shape # output: torch.Size([16, 9216]) -> This should be [16, 1, 9216] according to "[B,1,N_BLOCKS*emb_size]".
The README has been fixed in the latest commit.
- get_timestamp_embedding() seems to have a problem; it raises the error below.
audio = torch.randn(1, 20000)
emb_timestamp,t = get_timestamp_embedding(audio,model)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[11], line 2
1 audio = torch.randn(1, 20000)
----> 2 emb_timestamp,t = get_timestamp_embedding(audio,model)
File audiossl/audiossl/methods/atstframe/embedding.py:121, in get_timestamp_embedding(audio, model)
118 mel_chunk=mel[:,:,:,start:end]
119 len_chunk = torch.tensor([mel_chunk.shape[-1]]).expand(mel.shape[0]).to(audio.device)
--> 121 output_chunk = model.get_intermediate_layers(mel_chunk,len_chunk,n=N_BLOCKS,scene=False)
123 output.append(output_chunk)
124 output=torch.cat(output,dim=1)
File audiossl/audiossl/methods/atstframe/audio_transformer.py:279, in FrameAST.get_intermediate_layers(self, x, length, n, scene)
277 output.append(torch.mean(x[:,:self.nprompt],dim=1))
278 else:
--> 279 output.append(norm_x(x[:,self.nprompt:]))
281 return torch.cat(output,dim=-1)
TypeError: 'Tensor' object is not callable
This issue was solved by commit 7378690. Please pull the latest commit.
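The scene-embedding shape has been documented both as [B,1,N_BLOCKS*emb_size] and as the observed [B, D]. Downstream code can accept either layout. This is a minimal, pure-Python sketch of that idea. `scene_dim` is a hypothetical helper. It relies on `torch.Size` being a subclass of `tuple`, so `emb_scene.shape` can be passed to it directly.

```python
def scene_dim(shape, batch_size: int) -> int:
    """Return D for a scene embedding of shape (B, D) or (B, 1, D).

    Hypothetical helper; torch.Size subclasses tuple, so emb_scene.shape works as input.
    """
    shape = tuple(shape)
    if len(shape) == 3 and shape[1] == 1:
        shape = (shape[0], shape[2])  # drop the singleton axis of the [B,1,D] layout
    if len(shape) != 2 or shape[0] != batch_size:
        raise ValueError(f"unexpected scene embedding shape {shape}")
    return shape[1]


# With torch: d = scene_dim(emb_scene.shape, batch_size=16)
print(scene_dim((16, 9216), 16))     # 9216 (shape observed in this thread)
print(scene_dim((16, 1, 9216), 16))  # 9216 (shape in the original README)
```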
@daisukelab Thank you for your contribution. It's awesome and looks great.
@lmaxwell FYI -- We added an ATST wrapper to our evaluation package for audio representations.
- https://github.com/nttcslab/eval-audio-repr
- https://github.com/nttcslab/eval-audio-repr/blob/main/evar/ar_atst.py
- https://github.com/nttcslab/eval-audio-repr/blob/main/evar/ar_atst_frame.py
Please let me know if anything is wrong with ours. Thank you.
@lmaxwell Thanks for adding ATST Frame codes.
I'm trying to follow:
https://github.com/Audio-WestlakeU/audiossl/blob/main/audiossl/methods/atstframe/README.md
However, I suspect the checkpoint file contains model weights only, because the following doesn't work:
model = load_model("audiotransformer_base_mAP_4999.pt")
I'm trying to implement the ATST Frame wrapper for our evaluation environment (https://github.com/nttcslab/eval-audio-repr).
We would be happy if a working example using the weight file were available.
Thanks in advance.
BTW, I forgot to upload the ATST (Clip) wrapper for the evaluator (https://github.com/nttcslab/eval-audio-repr); I will do that too.
@lmaxwell Hi, I appreciate your support. My apologies, it was basically my mistake: I was using the wrong checkpoint file.
With the correct checkpoint I was able to run the get_scene_embedding example, but I also found two specific issues.
- get_scene_embedding() returns [B, D], where B and D are the batch size and feature dimension, whereas according to the README it should be [B,1,N_BLOCKS*emb_size].
audio = torch.randn(16, 1, 20000)
emb_scene = get_scene_embedding(audio,model)
emb_scene.shape # output: torch.Size([16, 9216]) -> This should be [16, 1, 9216] according to "[B,1,N_BLOCKS*emb_size]".
- get_timestamp_embedding() seems to have a problem; it raises the error below.
audio = torch.randn(1, 20000)
emb_timestamp,t = get_timestamp_embedding(audio,model)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[11], line 2
1 audio = torch.randn(1, 20000)
----> 2 emb_timestamp,t = get_timestamp_embedding(audio,model)
File audiossl/audiossl/methods/atstframe/embedding.py:121, in get_timestamp_embedding(audio, model)
118 mel_chunk=mel[:,:,:,start:end]
119 len_chunk = torch.tensor([mel_chunk.shape[-1]]).expand(mel.shape[0]).to(audio.device)
--> 121 output_chunk = model.get_intermediate_layers(mel_chunk,len_chunk,n=N_BLOCKS,scene=False)
123 output.append(output_chunk)
124 output=torch.cat(output,dim=1)
File audiossl/audiossl/methods/atstframe/audio_transformer.py:279, in FrameAST.get_intermediate_layers(self, x, length, n, scene)
277 output.append(torch.mean(x[:,:self.nprompt],dim=1))
278 else:
--> 279 output.append(norm_x(x[:,self.nprompt:]))
281 return torch.cat(output,dim=-1)
TypeError: 'Tensor' object is not callable
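The traceback shows that `norm_x(...)` at audio_transformer.py:279 fails because `norm_x` holds a Tensor, so it cannot be called like a function or layer. I don't know exactly what commit 7378690 changed. The following is therefore only a generic, pure-Python illustration of this class of bug: a name is bound to a computed value and then called as if it were a function. `normalize`, `buggy_head` and `fixed_head` are hypothetical names.

```python
def normalize(xs):
    """Stand-in for a normalization layer: scale values to unit peak magnitude."""
    peak = max((abs(x) for x in xs), default=0.0) or 1.0
    return [x / peak for x in xs]


def buggy_head(xs):
    norm_x = normalize(xs)  # norm_x now holds the result (a list), not a function
    return norm_x(xs)       # calling the result raises TypeError


def fixed_head(xs):
    return normalize(xs)    # call the normalization function itself


try:
    buggy_head([1.0, -2.0])
except TypeError as err:
    print(err)  # 'list' object is not callable
print(fixed_head([1.0, -2.0]))  # [0.5, -1.0]
```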
@lmaxwell Hi, I confirmed that get_timestamp_embedding() is working fine now. That'd make everybody happy. :)
I think I should close this.
Thank you very much again!