import os
import time

import pyarrow as pa
import pyarrow.parquet as pq
from datasets import load_dataset
from transformers import AutoTokenizer

# --- Configuration ---
FINETUNING_DATASET = "HuggingFaceFW/fineweb"
DATASET_SPLIT = "train"
TOKENIZER_NAME = "HuggingFaceTB/SmolLM2-1.7B"
TARGET_TOKENS = 6_000_000_000  # 6 billion tokens
OUTPUT_DIR = "../dataset"
CHUNK_SIZE = 50_000  # Number of documents to collect before writing a batch

def download_and_save_in_chunks():
    """Streams Fineweb, tokenizes, and saves to Parquet in memory-efficient chunks."""

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    OUTPUT_FILE_PATH = os.path.join(OUTPUT_DIR, "data.parquet")

    # 1. Load Tokenizer & Dataset in Streaming Mode
    print(f"Loading tokenizer: {TOKENIZER_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

    print(f"Loading dataset: {FINETUNING_DATASET} in streaming mode...")
    streaming_dataset = load_dataset(FINETUNING_DATASET, split=DATASET_SPLIT, streaming=True)

    # 2. Define Schema (Crucial for ParquetWriter)
    # Based on Fineweb's fields: 'text', 'meta', 'id'
    schema = pa.schema([
        pa.field('text', pa.string()),
        pa.field('meta', pa.struct([
            pa.field('url', pa.string()),
            pa.field('dump', pa.string()),
            pa.field('s_cluster', pa.int64()),
            pa.field('token_count', pa.int64()),  # Use their count, but rely on ours for stopping
        ])),
        pa.field('id', pa.string())
    ])

    current_tokens = 0
    collected_batch = []

    # Initialize Parquet Writer
    writer = None

    print("\n--- Starting Stream and Chunked Write to Disk ---")
    start_time = time.time()

    try:
        for i, example in enumerate(streaming_dataset):
            # Tokenize and check limit
            tokens = len(tokenizer.encode(example['text']))

            if current_tokens + tokens > TARGET_TOKENS:
                print("Token limit reached! Stopping stream.")
                break

            collected_batch.append(example)
            current_tokens += tokens

            # 3. Write Batch to Disk when CHUNK_SIZE is reached
            if (i + 1) % CHUNK_SIZE == 0:
                print(f"Writing batch of {CHUNK_SIZE:,} documents...")

                # Convert list of dicts to PyArrow Table
                # We extract the 'meta' fields to match the schema structure
                meta_list = [d.pop('meta', {}) for d in collected_batch]

                # Flatten the data structure for PyArrow Table creation
                flat_data = {
                    'text': [d['text'] for d in collected_batch],
                    'id': [d['id'] for d in collected_batch],
                    # Recreate the structured meta column
                    'meta': pa.array(meta_list, type=schema.field('meta').type)
                }

                # Create PyArrow Table
                table = pa.Table.from_arrays(
                    [pa.array(flat_data['text']), flat_data['meta'], pa.array(flat_data['id'])],
                    schema=schema
                )

                if writer is None:
                    # Initialize writer on first run
                    writer = pq.ParquetWriter(OUTPUT_FILE_PATH, table.schema, compression='SNAPPY')

                # Write the batch to disk
                writer.write_table(table)

                # Clear the batch list to free memory
                collected_batch = []

                print(f"Total documents written so far: {i + 1:,} | Total tokens: {current_tokens:,}")

    finally:
        # 4. Final Write (any remaining documents) and Cleanup
        if collected_batch:
            # Handle the last, incomplete batch
            meta_list = [d.pop('meta', {}) for d in collected_batch]

            flat_data = {
                'text': [d['text'] for d in collected_batch],
                'id': [d['id'] for d in collected_batch],
                'meta': pa.array(meta_list, type=schema.field('meta').type)
            }

            table = pa.Table.from_arrays(
                [pa.array(flat_data['text']), flat_data['meta'], pa.array(flat_data['id'])],
                schema=schema
            )

            if writer is None:
                # Handle the edge case where the total is less than CHUNK_SIZE
                writer = pq.ParquetWriter(OUTPUT_FILE_PATH, table.schema, compression='SNAPPY')

            writer.write_table(table)

        if writer:
            writer.close()

    end_time = time.time()
    print("\n--- Final Result ---")
    print(f"✅ Successfully created Parquet file: {os.path.abspath(OUTPUT_FILE_PATH)}")
    print(f"Final token count (approx): {current_tokens:,}")
    print(f"Total time for streaming and saving: {end_time - start_time:.2f} seconds")

if __name__ == "__main__":
    download_and_save_in_chunks()
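
# Optional follow-up (not part of the original script): a minimal sketch of how the
# resulting file could be sanity-checked with pyarrow's Parquet metadata reader,
# which reports row and row-group counts without loading the data. The path below
# simply mirrors OUTPUT_DIR/"data.parquet" as configured above.
#
#     import pyarrow.parquet as pq
#     metadata = pq.read_metadata("../dataset/data.parquet")
#     print(f"rows={metadata.num_rows:,}, row_groups={metadata.num_row_groups}")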