Comments (4)
The distillation methods KD-Lib provides are designed primarily for classification tasks. Hence, the distiller objects expect dataloaders that supply two things: the input data for the classification task and a corresponding label. In your case, the dataloaders seem to be supplying three things (input data, attention masks, and labels), while only two are expected.
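One possible way to get from three items to two is to regroup each batch before it reaches the distiller. The sketch below is only an illustration: `PairLoader` is a hypothetical helper, not part of KD-Lib, and it assumes the teacher and student models are wrapped so that they accept the packed `(input_ids, attention_mask)` pair as their single input. If the distiller moves `data` to a device with `.to(...)`, a plain tuple will not work and the two tensors would need to be combined into one tensor instead.

```python
class PairLoader:
    """Hypothetical adapter (not part of KD-Lib).

    Wraps any iterable that yields (input_ids, attention_mask, label)
    and yields ((input_ids, attention_mask), label) instead.
    """

    def __init__(self, loader):
        self.loader = loader
        # Forward `dataset` in case the training loop reads it (assumption).
        self.dataset = getattr(loader, "dataset", None)

    def __iter__(self):
        for input_ids, attention_mask, label in self.loader:
            # Pack both model inputs into one object so each batch has 2 items.
            yield (input_ids, attention_mask), label

    def __len__(self):
        return len(self.loader)


# Stand-in batches made of plain lists; a real DataLoader over a
# three-tensor dataset yields batches with the same 3-item layout.
batches = [
    ([101, 7592, 102], [1, 1, 1], 0),
    ([101, 2088, 102], [1, 1, 1], 1),
]

pairs = list(PairLoader(batches))
for (data, label) in pairs:  # now unpacks into exactly two names
    input_ids, attention_mask = data
```

The wrapper leaves the underlying loader untouched, so the same three-item loader can still be used elsewhere.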
from kd_lib.
Hi @OriAlpha.
Could you tell me what kind of NLP task you are looking to do? Also, could you please post the error stack trace if possible?
Sorry, I forgot to mention that I am following the distillation example in the README.
I am working on a SequenceClassification task, and the error was:
    for (data, label) in self.train_loader:
ValueError: too many values to unpack (expected 2)
I suspect my custom dataloader is causing the issue:
import torch
from transformers import BertTokenizer
# tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
# Tokenize all of the sentences and map the tokens to their word IDs.
input_ids = []
attention_masks = []
# For every sentence...
for sent in sentences:
    # `encode_plus` will:
    #   (1) Tokenize the sentence.
    #   (2) Prepend the `[CLS]` token to the start.
    #   (3) Append the `[SEP]` token to the end.
    #   (4) Map tokens to their IDs.
    #   (5) Pad or truncate the sentence to `max_length`.
    #   (6) Create attention masks for [PAD] tokens.
    encoded_dict = tokenizer.encode_plus(
        sent,                          # Sentence to encode.
        add_special_tokens = True,     # Add '[CLS]' and '[SEP]'.
        max_length = 100,              # Pad & truncate all sentences.
        pad_to_max_length = True,
        return_attention_mask = True,  # Construct attn. masks.
        return_tensors = 'pt',         # Return PyTorch tensors.
    )
    # Add the encoded sentence to the list.
    input_ids.append(encoded_dict['input_ids'])
    # And its attention mask (simply differentiates padding from non-padding).
    attention_masks.append(encoded_dict['attention_mask'])
# Convert the lists into tensors.
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(labels)
# Print sentence 0, now as a list of IDs.
print('Original: ', sentences[0])
print('Token IDs:', input_ids[0])
### Now combine the input IDs, masks, and labels and divide the dataset:
from torch.utils.data import TensorDataset, random_split
# Combine the training inputs into a TensorDataset.
dataset = TensorDataset(input_ids, attention_masks, labels)
# Create a 90-10 train-validation split.
# Calculate the number of samples to include in each set.
train_size = int(0.90 * len(dataset))
val_size = len(dataset) - train_size
# Divide the dataset by randomly selecting samples.
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
print('{:>5,} training samples'.format(train_size))
print('{:>5,} validation samples'.format(val_size))
### Now create the loaders for these datasets
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
# The DataLoader needs to know our batch size for training, so we specify it
# here. For fine-tuning BERT on a specific task, the authors recommend a batch
# size of 16 or 32.
batch_size = 32
# Create the DataLoaders for our training and validation sets.
# We'll take training samples in random order.
train_dataloader = DataLoader(
    train_dataset,                           # The training samples.
    sampler = RandomSampler(train_dataset),  # Select batches randomly.
    batch_size = batch_size                  # Train with this batch size.
)
# For validation the order doesn't matter, so we'll just read them sequentially.
validation_dataloader = DataLoader(
    val_dataset,                             # The validation samples.
    sampler = SequentialSampler(val_dataset),  # Pull out batches sequentially.
    batch_size = batch_size                  # Evaluate with this batch size.
)
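Because the `TensorDataset` above holds three tensors, each batch from `train_dataloader` carries three items, while the loop in the traceback unpacks only two. A minimal sketch of that mismatch, using a plain list of tuples as a stand-in for the loader (the string values are placeholders, not real tensors):

```python
# Stand-in for a loader whose batches have three items each,
# like batches drawn from a dataset built from three tensors.
train_loader = [("input_ids", "attention_mask", "label")]

message = None
try:
    # Same unpacking pattern as the line shown in the traceback.
    for (data, label) in train_loader:
        pass
except ValueError as err:
    message = str(err)

print(message)
```

The printed message is the same "too many values to unpack" error reported in the traceback; the exact wording can vary slightly between Python versions.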
Feel free to close this if your issue has been resolved @OriAlpha.