Tune your LightGBM model: Avoid memory spikes and train more efficiently
Author
Andreas Wagenmann
Published
May 5, 2025
Figure 1: Leaner Sampling
What is this about?
Anyone who has trained a LightGBM model on a large amount of data will have observed a significant initial memory spike. It is caused by the sampling LightGBM performs to construct the dataset it trains on, and it is problematic because the memory required for training LightGBM models is largely determined by this initial step.
This article describes how to pull the sampling step out of training and how to run it more efficiently by utilizing partitioned HDF5 files and the LightGBM Dataset API.
The general principle of the solution is simple:
Store the source data in a format that allows sampling the n-th row without loading the whole dataset into memory (the HDF5 binary format)
Use the LightGBM Dataset API to run the sampling procedure directly on the HDF5 files
Store the resulting Dataset as a binary file to be used for training
Start your training
As per the official webpage of the Python interface to HDF5 (Official h5py Website), HDF5 provides the following:
HDF5 lets you store huge amounts of numerical data, and easily manipulate that data from NumPy. For example, you can slice into multi-terabyte datasets stored on disk, as if they were real NumPy arrays. Thousands of datasets can be stored in a single file, categorized and tagged however you want.
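For instance, a slice of a large on-disk dataset can be read without loading the rest of the file. A small sketch of this (the file name and dataset key here are purely illustrative):

Code

import h5py

# Opening the file is lazy: nothing is loaded into memory yet
with h5py.File("some_large_file.hdf5", "r") as f:
    x = f["X"]             # handle to the on-disk dataset
    batch = x[1000:2000]   # only this slice is read from disk, returned as a NumPy array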
What is demonstrated in the article?
The article shows how to convert DataFrames to HDF5 files and subsequently use them to generate the Dataset needed for LightGBM training.
The example uses Microsoft's Learning to Rank datasets, specifically the WEB10K variant (MSLR Project Page).
The HDF5 / Dataset specific code shown here is a modified variant of, and heavily based on, an example provided by the LightGBM team (LightGBM Code Sample), with some adjustments and additional comments.
Update the settings below to point to your local file paths
Execute the provided code (NOTE: this will store data such as HDF5 files and Dataset binaries, along with the trained models, in the same folder where your code is located)
Code
# NOTE: update the paths to the location where you stored the MSLR-WEB10K data
# (or another variant of the MSLR datasets)
import getpass

data_path = f"/Users/{getpass.getuser()}/DATA/LTR/MSLR-WEB10K/Fold1"
train_path = f"{data_path}/train.txt"
test_path = f"{data_path}/test.txt"
validation_path = f"{data_path}/vali.txt"
Needed Imports
Code
import sys
import math
from sklearn.datasets import load_svmlight_file
import numpy as np
import pandas as pd
import lightgbm as lgb
import h5py
from pathlib import Path
from typing import Dict, List, Union
import json
Code for Dataset creation utilizing HDF5
Data structure for retrieving sample N from HDF5. Needed for memory-efficient DataSet creation:
Code
class HDFSequence(lgb.Sequence):
    """
    Construct a sequence object from HDF5 with required interface.

    Parameters
    ----------
    hdf_dataset : h5py.Dataset
        Dataset in HDF5 file.
    batch_size : int
        Size of a batch. When reading data to construct lightgbm Dataset, each read reads batch_size rows.

    HDF Sequence which takes hdf datasets from a file, e.g. via h5py.File(file_path, "r")[dataset_key],
    where dataset_key would be "X" if the dataset was stored as "X" on creation of the hdf5 file.
    This allows us to draw batches from multiple files, allowing to process data that is larger than
    memory, and batch-wise construct the dataset needed for training.

    Example:
        f = h5py.File('train.hdf5', 'r')
        train_data = lgb.Dataset(HDFSequence(f['X'], 8192), label=f['Y'][:])
    """

    def __init__(self, hdf_dataset, batch_size):
        self.data = hdf_dataset
        self.batch_size = batch_size

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)
Helper functions for converting NumPy arrays and DataFrames to HDF5 and for reading HDF5 data:
Code
def save2hdf(input_data: Dict[str, any], file_path: str, chunk_size: int):
    """Store numpy array to HDF5 file.

    Please note chunk size settings in the implementation for I/O performance optimization.
    """
    with h5py.File(file_path, "w") as f:
        for name, data in input_data.items():
            nrow, ncol = data.shape
            if ncol == 1:
                # Y has a single column and we read it in single shot. So store it as a 1-d array.
                chunk = (nrow,)
                data = data.values.flatten()
            else:
                # We use random access for data sampling when creating LightGBM Dataset from Sequence.
                # When accessing any element in an HDF5 chunk, it's read entirely.
                # To save I/O for sampling, we should keep the number of total chunks much larger than
                # the sample count. Here we are just creating a chunk size that matches the batch_size.
                #
                # Also note that the data is stored in row major order to avoid an extra copy when
                # passing to the lightgbm Dataset.
                chunk = (chunk_size, ncol)
            f.create_dataset(name, data=data, chunks=chunk, compression="lzf")


def store_df_as_hdf5(x_df: pd.DataFrame,
                     y_df: pd.DataFrame,
                     groups: Union[List[str], pd.Series, np.ndarray],
                     feature_names: Union[None, List[str], pd.Series, np.ndarray],
                     categorical_features: Union[None, List[str], pd.Series, np.ndarray],
                     file_path: str,
                     chunk_size: int):
    """
    Store features, targets and corresponding groups in an hdf5 file to pull them out during
    dataset creation.

    NOTE: hdf5 files also allow storing attributes for each dataset.
        - Setting an attribute value: f[dataset_key].attrs[attr_name] = attr_value
        - After loading the hdf5 file, check which attributes are set: list(f[dataset_key].attrs.keys())
        - Select an attribute: f[dataset_key].attrs[attr_name]

    Note that in the save2hdf function chunks are defined, which are important to optimize I/O
    (on accessing an element, the whole chunk it belongs to is read; thus it is advised to have the
    number of chunks much larger than the bin_construct_sample_cnt parameter set on dataset creation,
    default 200000).
    """
    store_dict = {
        "X": x_df,
        "Y": y_df,
        "groups": pd.DataFrame(groups)
    }
    if feature_names is not None:
        store_dict["feature_names"] = pd.DataFrame(feature_names)
    if categorical_features is not None:
        store_dict["categorical_features"] = pd.DataFrame(categorical_features)
    save2hdf(input_data=store_dict, file_path=file_path, chunk_size=chunk_size)


def read_hdf5_data(path_to_hdf5_file: str):
    """
    Read an hdf5 file.

    After creating the hdf5 file object, you can:
        - get available dataset keys with list(f.keys()) and the respective dataset with f[dataset_key]
        - get attributes for a dataset with list(f[dataset_key].attrs.keys()) and a specific attr value
          with f[dataset_key].attrs[attr_key]
        - access the data via the basic range and index operations (e.g. f[dataset_key][:])
    """
    return h5py.File(path_to_hdf5_file, 'r')
Code for creating a lightgbm DataSet from multiple hdf5 file paths:
Code
def create_dataset_from_multiple_hdf(input_flist: List[str],
                                     batch_size: int,
                                     output_file: str,
                                     bin_construct_sample_cnt: int = 200000,
                                     max_bin: int = 255,
                                     reference_dataset: lgb.Dataset = None):
    """
    Assumes that each hdf5 file stores the data in the following datasets:
        - features under key `X`
        - targets under key `Y`
        - groups under key `groups`
        - feature_names under key `feature_names`
        - categorical_features under key `categorical_features`

    With the binary dataset created, we can use either the Python API or the cmdline version to train,
    saving the dataset preparation step and memory, since by using HDFSequence here the dataset is
    created iteratively by pulling batches of data rather than holding the full data in memory.

    The Dataset will use column names like ["0", "1", "2", ...]
    """
    data = []
    ylist = []
    grouplist = []
    features = None
    categorical_features = None
    for f in input_flist:
        f = h5py.File(f, "r")
        # features and categorical features are not specific to each dataset, thus we assume the same
        # order in all files and just pull them once
        if features is None and "feature_names" in list(f.keys()):
            features = f["feature_names"]
        if categorical_features is None and "categorical_features" in list(f.keys()):
            categorical_features = f["categorical_features"]
        data.append(HDFSequence(f["X"], batch_size))
        ylist.append(f["Y"][:])
        grouplist.append(f["groups"][:])

    # these are also the defaults right now
    # note that increasing max_bin also affects the storage type and thus size of the features,
    # e.g. for a value of 255, uint8_t will be used, while 256 already needs a higher precision type
    params = {
        "bin_construct_sample_cnt": bin_construct_sample_cnt,
        "max_bin": max_bin
    }
    y = np.concatenate(ylist)
    groups = np.concatenate(grouplist)
    if categorical_features is None:
        categorical_features = "auto"
    if features is None:
        features = "auto"
    dataset = lgb.Dataset(data,
                          label=y,
                          params=params,
                          group=groups,
                          feature_name=features,
                          categorical_feature=categorical_features,
                          reference=reference_dataset)
    dataset.save_binary(output_file)
    return dataset
Loading the files in svmlight format (since this is the format of the Microsoft LTR datasets) and storing them as HDF5 files:
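The loading code itself is not reproduced above; a minimal sketch, assuming the store_df_as_hdf5 helper from earlier, the file paths defined at the top, and an illustrative chunk size of 8192, could look like the following (the svmlight_to_hdf5 helper name is introduced here purely for illustration):

Code

def svmlight_to_hdf5(svmlight_path: str, hdf5_path: str, chunk_size: int = 8192):
    # load_svmlight_file returns a sparse feature matrix, the labels and the query ids
    X, y, qid = load_svmlight_file(svmlight_path, query_id=True)
    x_df = pd.DataFrame(X.toarray())
    y_df = pd.DataFrame(y)
    # LightGBM expects group sizes (number of rows per query) in order of appearance;
    # rows of a query are contiguous in the MSLR files
    _, first_idx, counts = np.unique(qid, return_index=True, return_counts=True)
    group_sizes = counts[np.argsort(first_idx)]
    store_df_as_hdf5(x_df=x_df,
                     y_df=y_df,
                     groups=group_sizes,
                     feature_names=None,
                     categorical_features=None,
                     file_path=hdf5_path,
                     chunk_size=chunk_size)

svmlight_to_hdf5(train_path, "./train_data.hdf5")
svmlight_to_hdf5(test_path, "./test_data.hdf5")
svmlight_to_hdf5(validation_path, "./validation_data.hdf5")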
Use hdf5 files to generate datasets iteratively by loading batches from hdf5 files
This step allows creating Datasets from data that would not fit into memory all at once. The created Dataset binaries occupy only around 10% of the storage size of the original data.
Code
# NOTE: for datasets to be used as validation data, the original train data has to be set as reference,
# otherwise we might see errors when running lgb.train with valid_sets
train_dataset = create_dataset_from_multiple_hdf(input_flist=["./train_data.hdf5"],
                                                 batch_size=8192,
                                                 output_file="./train_data.bin")
create_dataset_from_multiple_hdf(input_flist=["./test_data.hdf5"],
                                 batch_size=8192,
                                 output_file="./test_data.bin")
create_dataset_from_multiple_hdf(input_flist=["./validation_data.hdf5"],
                                 batch_size=8192,
                                 output_file="./validation_data.bin",
                                 reference_dataset=train_dataset)
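As a quick, optional sanity check of the storage footprint mentioned above, the original text files can be compared with the generated Dataset binaries. This snippet is illustrative and assumes the paths defined earlier:

Code

# Compare the size of the original text files to the binary LightGBM datasets created above
for name, (txt_path, bin_path) in {
    "train": (train_path, "./train_data.bin"),
    "test": (test_path, "./test_data.bin"),
    "validation": (validation_path, "./validation_data.bin"),
}.items():
    txt_mb = Path(txt_path).stat().st_size / 1e6
    bin_mb = Path(bin_path).stat().st_size / 1e6
    print(f"{name}: {txt_mb:.0f} MB original -> {bin_mb:.0f} MB dataset binary")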
Fitting LightGBM Ranker
Note that the provided settings are solely for presentation purposes and not optimized in any way.
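The training call could then look like the following minimal sketch; the lambdarank objective and all listed parameters are illustrative defaults rather than tuned settings:

Code

# Load the previously saved binary Datasets; no HDF5 sampling happens at this point anymore
train_dataset = lgb.Dataset("./train_data.bin")
validation_dataset = lgb.Dataset("./validation_data.bin", reference=train_dataset)

# Illustrative, untuned parameters for a LambdaRank model
params = {
    "objective": "lambdarank",
    "metric": "ndcg",
    "ndcg_eval_at": [5, 10],
    "learning_rate": 0.05,
    "num_leaves": 63,
}

booster = lgb.train(
    params,
    train_dataset,
    num_boost_round=100,
    valid_sets=[validation_dataset],
)

booster.save_model("./lgb_ranker.txt")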
Moving the Dataset creation ahead of the training step and utilizing the HDF5 file format together with the LightGBM Dataset API can significantly reduce the resources needed during training. This allows you to use smaller instances, since the memory requirements for training are often determined solely by the initial Dataset sampling and creation step.