Source code for ls_mlkit.dataset.regular_language

import exrex
import torch
from torch.utils.data import DataLoader, Dataset, random_split


class RegularLanguageDataset(Dataset):
    """Dataset of strings enumerated from a regular expression.

    Strings are produced by ``exrex.generate`` and kept only when their
    length does not exceed ``max_len``. Collection stops as soon as
    ``data_size`` strings have been kept, so the dataset never holds more
    than ``data_size`` items (it may hold fewer if the generator runs out).

    Args:
        regex_pattern: Regular expression handed to ``exrex.generate``.
        max_len: Maximum length (inclusive) of a string to keep.
        data_size: Upper bound on the number of strings stored. A value
            ``<= 0`` yields an empty dataset.
        limit: Forwarded unchanged as the ``limit`` keyword of
            ``exrex.generate`` (presumably bounds unbounded repetitions
            such as ``*`` and ``+`` -- confirm against the exrex docs).
    """

    def __init__(self, regex_pattern: str, max_len: int = 10, data_size: int = 100, limit: int = 100) -> None:
        self.regex_pattern = regex_pattern
        self.data: list = []

        # Nothing to collect; skip consuming the generator entirely.
        if data_size <= 0:
            return

        for candidate in exrex.generate(self.regex_pattern, limit=limit):
            if len(candidate) > max_len:
                continue
            self.data.append(candidate)
            # Stop at exactly ``data_size`` items. The previous ``>`` check
            # let one extra string through before breaking.
            if len(self.data) >= data_size:
                break

    def __len__(self) -> int:
        """Return the number of stored strings."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the stored string at position ``idx``."""
        return self.data[idx]
def get_regular_language_dataset(regex_pattern, max_len=10, data_size=100, limit=100, test_ratio=0.2, **kwargs):
    """Build a regular-language dataset and split it into train/test subsets.

    Args:
        regex_pattern: Regular expression passed to ``RegularLanguageDataset``.
        max_len: Forwarded to ``RegularLanguageDataset``.
        data_size: Forwarded to ``RegularLanguageDataset``.
        limit: Forwarded to ``RegularLanguageDataset``.
        test_ratio: Fraction of the samples assigned to the test split; the
            train size is ``int((1 - test_ratio) * num_samples)`` and the
            test split receives the remainder.
        **kwargs: Accepted but not used by this function.

    Returns:
        A 3-tuple ``(train, test, test)``. The same test subset object is
        returned in both the second and third positions.
    """
    full_dataset = RegularLanguageDataset(regex_pattern, max_len=max_len, data_size=data_size, limit=limit)

    total = len(full_dataset)
    n_train = int((1 - test_ratio) * total)
    split_lengths = [n_train, total - n_train]

    # Fixed seed so the split is reproducible from call to call.
    seeded_rng = torch.Generator().manual_seed(0)
    train_split, test_split = random_split(full_dataset, split_lengths, generator=seeded_rng)

    return train_split, test_split, test_split
if __name__ == "__main__":
    # Small demo: build a dataset from the pattern ``a*`` and print it in
    # shuffled batches of four strings.
    demo_options = {"max_len": 20, "data_size": 100, "limit": 100}
    demo_dataset = RegularLanguageDataset(r"a*", **demo_options)

    for batch in DataLoader(demo_dataset, batch_size=4, shuffle=True):
        print(batch)