Initial commit

scripts/download-dataset.py (new file, +131 lines)

@@ -0,0 +1,131 @@
import os
import time
import pyarrow as pa
import pyarrow.parquet as pq
from datasets import load_dataset
from transformers import AutoTokenizer

# --- Configuration ---
FINETUNING_DATASET = "HuggingFaceFW/fineweb"
DATASET_SPLIT = "train"
TOKENIZER_NAME = "HuggingFaceTB/SmolLM2-1.7B"
TARGET_TOKENS = 6_000_000_000  # 6 billion tokens
OUTPUT_DIR = "../dataset"
CHUNK_SIZE = 50_000  # Number of documents to collect before writing a batch


def download_and_save_in_chunks():
    """Streams Fineweb, tokenizes, and saves to Parquet in memory-efficient chunks."""

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    OUTPUT_FILE_PATH = os.path.join(OUTPUT_DIR, "data.parquet")

    # 1. Load Tokenizer & Dataset in Streaming Mode
    print(f"Loading tokenizer: {TOKENIZER_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

    print(f"Loading dataset: {FINETUNING_DATASET} in streaming mode...")
    streaming_dataset = load_dataset(FINETUNING_DATASET, split=DATASET_SPLIT, streaming=True)

    # 2. Define Schema (crucial for ParquetWriter)
    # Based on Fineweb's fields: 'text', 'meta', 'id'
    schema = pa.schema([
        pa.field('text', pa.string()),
        pa.field('meta', pa.struct([
            pa.field('url', pa.string()),
            pa.field('dump', pa.string()),
            pa.field('s_cluster', pa.int64()),
            pa.field('token_count', pa.int64()),  # Use their count, but rely on ours for stopping
        ])),
        pa.field('id', pa.string())
    ])
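
    # NOTE (assumption): the struct above mirrors the 'meta' layout this script expects.
    # Current FineWeb configs may expose metadata as flat top-level columns rather than a
    # nested 'meta' dict, so it is worth inspecting one streamed example before a long run,
    # e.g. print(next(iter(streaming_dataset)).keys()), and adjusting the schema to match.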

    current_tokens = 0
    collected_batch = []

    # Initialize Parquet Writer
    writer = None

    print("\n--- Starting Stream and Chunked Write to Disk ---")
    start_time = time.time()

    try:
        for i, example in enumerate(streaming_dataset):
            # Tokenize and check limit
            tokens = len(tokenizer.encode(example['text']))
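            # Note: depending on the tokenizer, encode() may prepend/append special
            # tokens, so this per-document count (and the running total) is approximate.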

            if current_tokens + tokens > TARGET_TOKENS:
                print("Token limit reached! Stopping stream.")
                break

            collected_batch.append(example)
            current_tokens += tokens

            # 3. Write Batch to Disk when CHUNK_SIZE is reached
            if (i + 1) % CHUNK_SIZE == 0:
                print(f"Writing batch of {CHUNK_SIZE:,} documents...")

                # Convert list of dicts to PyArrow Table
                # We extract the 'meta' fields to match the schema structure
                meta_list = [d.pop('meta', {}) for d in collected_batch]

                # Flatten the data structure for PyArrow Table creation
                flat_data = {
                    'text': [d['text'] for d in collected_batch],
                    'id': [d['id'] for d in collected_batch],
                    # Recreate the structured meta column
                    'meta': pa.array(meta_list, type=schema.field('meta').type)
                }

                # Create PyArrow Table
                table = pa.Table.from_arrays(
                    [pa.array(flat_data['text']), flat_data['meta'], pa.array(flat_data['id'])],
                    schema=schema
                )

                if writer is None:
                    # Initialize writer on first run
                    writer = pq.ParquetWriter(OUTPUT_FILE_PATH, table.schema, compression='SNAPPY')

                # Write the batch to disk
                writer.write_table(table)

                # Clear the batch list to free memory
                collected_batch = []

                print(f"Total documents written so far: {i + 1:,} | Total tokens: {current_tokens:,}")

    finally:
        # 4. Final Write (any remaining documents) and Cleanup
        if collected_batch:
            # Handle the last, incomplete batch
            meta_list = [d.pop('meta', {}) for d in collected_batch]

            flat_data = {
                'text': [d['text'] for d in collected_batch],
                'id': [d['id'] for d in collected_batch],
                'meta': pa.array(meta_list, type=schema.field('meta').type)
            }

            table = pa.Table.from_arrays(
                [pa.array(flat_data['text']), flat_data['meta'], pa.array(flat_data['id'])],
                schema=schema
            )

            if writer is None:
                # Handle the edge case where the total is less than CHUNK_SIZE
                writer = pq.ParquetWriter(OUTPUT_FILE_PATH, table.schema, compression='SNAPPY')

            writer.write_table(table)

        if writer:
            writer.close()

        end_time = time.time()
        print("\n--- Final Result ---")
        print(f"✅ Successfully created Parquet file: {os.path.abspath(OUTPUT_FILE_PATH)}")
        print(f"Final token count (approx): {current_tokens:,}")
        print(f"Total time for streaming and saving: {end_time - start_time:.2f} seconds")


if __name__ == "__main__":
    download_and_save_in_chunks()
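
A quick way to sanity-check the output once the script finishes (a minimal sketch, assuming the file was written to ../dataset/data.parquet as configured above):

import pyarrow.parquet as pq

pf = pq.ParquetFile("../dataset/data.parquet")
print(pf.metadata.num_rows, "rows")   # total documents written
print(pf.schema_arrow)                # should match the schema defined in the script

# The file can also be loaded back for training with the datasets library, e.g.:
# from datasets import load_dataset
# ds = load_dataset("parquet", data_files="../dataset/data.parquet", split="train")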