During model training, people usually focus on accelerating the model and improving GPU utilization, but sometimes the bottleneck is reading data: the GPU processes batches faster than the CPU can supply them. Frameworks do offer data-loading acceleration schemes, such as TensorFlow's tf.data.TFRecordDataset and PyTorch's DataLoader with its num_workers parameter for parallel worker loading. Some code instead packs all the data into a single binary file and loads it into memory for fast reads, but that approach cannot handle large datasets.
TensorFlow's TFRecord pipeline requires first converting the data into the record file format before reading it, and PyTorch's DataLoader can misbehave when num_workers is set to a non-zero value on some versions of Windows. This article introduces a scheme that uses Python multithreading to load and process data, combined with PyTorch's Dataset and DataLoader to fetch it, for your reference.
First, create a buffer class. Reads and writes are coordinated by two condition variables that share a single lock:
```python
import random
import threading


class Buffer:
    """A bounded, thread-safe FIFO buffer."""

    def __init__(self, size):
        self.size = size
        self.buffer = []
        self.lock = threading.Lock()
        # Consumers wait on has_data; producers wait on has_pos.
        self.has_data = threading.Condition(self.lock)
        self.has_pos = threading.Condition(self.lock)

    def get_size(self):
        return self.size

    def get(self):
        with self.has_data:
            while len(self.buffer) == 0:
                self.has_data.wait()
            result = self.buffer.pop(0)
            self.has_pos.notify_all()
            return result

    def put(self, data):
        with self.has_pos:
            while len(self.buffer) >= self.size:
                self.has_pos.wait()
            self.buffer.append(data)
            self.has_data.notify_all()


# Test workers (the original passed an undefined name `a` to put; fixed to `data`)
def get():
    while True:
        get_data = buffer.get()


def put():
    while True:
        data = random.randint(0, 9)
        buffer.put(data)
```
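To sanity-check the buffer, a short single-producer, single-consumer exercise can verify that items come out in FIFO order and that `put` blocks once the buffer is full. The class is repeated in compact form so the snippet runs standalone:

```python
import threading


class Buffer:  # same bounded FIFO as above, repeated so this snippet is self-contained
    def __init__(self, size):
        self.size = size
        self.buffer = []
        self.lock = threading.Lock()
        self.has_data = threading.Condition(self.lock)
        self.has_pos = threading.Condition(self.lock)

    def get(self):
        with self.has_data:
            while not self.buffer:
                self.has_data.wait()
            result = self.buffer.pop(0)
            self.has_pos.notify_all()
            return result

    def put(self, data):
        with self.has_pos:
            while len(self.buffer) >= self.size:
                self.has_pos.wait()
            self.buffer.append(data)
            self.has_data.notify_all()


buf = Buffer(size=3)
consumed = []


def producer():
    for item in range(10):
        buf.put(item)  # blocks whenever 3 items are already queued


def consumer():
    for _ in range(10):
        consumed.append(buf.get())  # blocks until an item is available


t1 = threading.Thread(target=producer)
t2 = threading.Thread(target=consumer)
t1.start(); t2.start()
t1.join(); t2.join()
print(consumed)  # FIFO order is preserved
```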
Buffer class reference: https://cloud.tencent.com/developer/article/1724559
Next, create a DataReader that writes data with multiple threads and reads it from a single thread. The key multithreading code follows:
```python
import random
import threading
import time

index = 0  # shared read position; incrementing it is not atomic across threads,
           # so files may occasionally be skipped or repeated (tolerable for shuffled training data)


class DataReader:
    def __init__(self, max_buffer_size=5000):
        self.audio_files = files_to_list(training_files)  # project helper
        random.shuffle(self.audio_files)
        self.buffer = Buffer(max_buffer_size)

    # Consume data (drains the buffer; useful for throughput testing)
    def consume(self):
        while True:
            result = self.buffer.get()

    # Produce data: each thread loads files and pushes them into the buffer
    def produce(self):
        global index
        while True:
            index += 1
            if index >= len(self.audio_files) - 1:
                index = 0
            file = self.audio_files[index]
            start = time.time()
            audio = load_wav(file)  # project helper that reads one wav file
            end = time.time()
            self.buffer.put(audio)  # blocks when the buffer is full

    def run_produce(self, thread_num=16):
        # Start thread_num producer threads; daemon=True lets the process exit cleanly
        for _ in range(thread_num):
            th = threading.Thread(target=self.produce, daemon=True)
            th.start()

    def get_item(self, index):
        # index is ignored: the buffer hands back whatever was loaded next
        result = self.buffer.get()
        return result
```
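The standard library's `queue.Queue` already implements the same bounded, blocking buffer semantics, so the DataReader pattern can be sketched without the hand-rolled Buffer. In this self-contained variant, `load_wav_stub` and `MiniReader` are illustrative names, and the file index is guarded by a per-instance lock rather than the racy global counter:

```python
import queue
import random
import threading


def load_wav_stub(path):
    # Stand-in for the real load_wav: pretend to decode a file
    return f"audio:{path}"


class MiniReader:
    def __init__(self, files, max_buffer_size=8, thread_num=4):
        self.files = list(files)
        random.shuffle(self.files)
        self.buffer = queue.Queue(maxsize=max_buffer_size)  # bounded, blocking
        self.index = 0
        self.index_lock = threading.Lock()  # avoids the racy shared counter
        for _ in range(thread_num):
            threading.Thread(target=self.produce, daemon=True).start()

    def next_index(self):
        with self.index_lock:
            i = self.index
            self.index = (self.index + 1) % len(self.files)
            return i

    def produce(self):
        while True:
            path = self.files[self.next_index()]
            self.buffer.put(load_wav_stub(path))  # blocks when the queue is full

    def get_item(self):
        return self.buffer.get()  # blocks until a producer delivers an item


reader = MiniReader([f"clip_{i}.wav" for i in range(20)])
batch = [reader.get_item() for _ in range(10)]
print(len(batch))
```

Using `queue.Queue` also frees you from writing the condition-variable logic by hand, at the cost of slightly less control over the locking.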
Now wrap the DataReader in a Dataset so PyTorch can fetch data through it:
```python
import torch


class AudioDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data_reader = DataReader()
        self.data_reader.run_produce()

    def __getitem__(self, index):
        # Pull the next pre-loaded clip from the buffer (index is ignored)
        audio = self.data_reader.get_item(index)
        # ... further data processing here ...
        audio = torch.from_numpy(audio).float()
        return audio.unsqueeze(0)

    def __len__(self):
        # The file list lives on the reader (the original referenced an
        # undefined self.audio_files here)
        return len(self.data_reader.audio_files)
```
Finally, a DataLoader iterates over the Dataset to pull batches and feed them to the model for training:
```python
dataset = AudioDataset()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,            # order comes from the buffer, so shuffling indices is pointless
    num_workers=num_workers,  # 0 is the safe choice here: the reader's own threads do the loading
    pin_memory=pin_memory,
)
```
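Because num_workers=0 keeps `__getitem__` in the main process, each batch is simply batch_size consecutive pops from the buffer. A torch-free simulation of that behavior (FakeDataset and the loop below are illustrative, not PyTorch internals):

```python
import queue
import threading

buffer = queue.Queue(maxsize=16)


def producer():
    for i in range(12):
        buffer.put([float(i)])  # pretend each item is one loaded audio clip


threading.Thread(target=producer, daemon=True).start()


class FakeDataset:
    def __len__(self):
        return 12

    def __getitem__(self, index):
        return buffer.get()  # index is ignored, exactly like AudioDataset


dataset = FakeDataset()
batch_size = 4
batches = []
# What DataLoader(..., shuffle=False, num_workers=0) does, in miniature:
for start in range(0, len(dataset), batch_size):
    batches.append([dataset[i] for i in range(start, start + batch_size)])
print(len(batches))  # 12 items / batch_size 4 = 3 batches
```

This also makes clear why shuffle=True would add nothing: the indices the loader generates never influence which item comes back.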