
python - How to split a dataset into train, validation and test sets based on the value of another column - Stack Overflow


Given a dataset of the form:

         date      user   f1     f2       rank   rank_group  counts
   0  09/09/2021  USER100  59.0  3599.9    1         1.0       3
   1  10/09/2021  USER100  75.29 80790.0   2         1.0       3
   2  11/09/2021  USER100  75.29 80790.0   3         1.0       3
   1  10/09/2021  USER100  75.29 80790.0   2         2.0       3
   2  11/09/2021  USER100  75.29 80790.0   3         2.0       3
   3  12/09/2021  USER100  75.29 80790.0   4         2.0       3
   2  11/09/2021  USER100  75.29 80790.0   3         3.0       3
   3  12/09/2021  USER100  75.29 80790.0   4         3.0       3
   4  13/09/2021  USER100  75.29 80790.0   5         3.0       3
   3  12/09/2021  USER100  75.29 80790.0   4         4.0       3
   4  13/09/2021  USER100  75.29 80790.0   5         4.0       3
   5  14/09/2021  USER100  75.29 80790.0   6         4.0       3
   4  13/09/2021  USER100  75.29 80790.0   5         5.0       3
   5  14/09/2021  USER100  75.29 80790.0   6         5.0       3
   6  15/09/2021  USER100  71.24 28809.9   7         5.0       3
   5  14/09/2021  USER100  75.29 80790.0   6         6.0       3
   6  15/09/2021  USER100  71.24 28809.9   7         6.0       3
   7  16/09/2021  USER100  71.31 79209.9   8         6.0       3
   6  15/09/2021  USER100  71.24 28809.9   7         7.0       3
   7  16/09/2021  USER100  71.31 79209.9   8         7.0       3
   8  17/09/2021  USER100  70.43 82809.9   9         7.0       3
   7  16/09/2021  USER100  71.31 79209.9   8         8.0       3
   8  17/09/2021  USER100  70.43 82809.9   9         8.0       3
   9  18/09/2021  USER100  68.65 82809.9   10        8.0       3
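To experiment with different splits, it can help to rebuild the relevant part of this table. The sketch below is an illustration that keeps only the rank and rank_group columns (date, user, f1, f2 and counts are omitted). It relies on the pattern visible in the table: each rank_group g covers the three consecutive ranks g, g + 1 and g + 2.

```python
import pandas as pd

# Rebuild only the rank and rank_group columns of the table above:
# rank_group g (1..8) contains the ranks g, g + 1 and g + 2.
rows = [
    {"rank": g + offset, "rank_group": float(g)}
    for g in range(1, 9)
    for offset in range(3)
]
df = pd.DataFrame(rows)

print(len(df))  # 24
print(df.groupby("rank_group")["rank"].apply(list))
```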

The rank_group column indicates that the dataset has 8 groups. I would like to split it into three datasets (train, validation and test) with ratios of 70%, 20% and 10% respectively. In this case, I would expect train_set to contain all rows with rank_group = 1.0, 2.0, 3.0, 4.0, 5.0, validation_set to contain all rows with rank_group = 6.0, 7.0, and test_set to contain all rows with rank_group = 8.0.

Approach I: using split from numpy

  • train, validation, test = np.split(user_dataset, [int(.7*len(user_dataset)), int(.2*len(user_dataset)), int(.1*len(user_dataset))])
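A rough illustration of why Approach I does not produce the expected split: `np.split` takes cumulative cut positions rather than section sizes, and it cuts by row position without looking at rank_group, so a group can land on both sides of a cut. The sketch uses a hypothetical NumPy array standing in for the rank_group column (8 groups of 3 rows, as in the table).

```python
import numpy as np

# Hypothetical stand-in for the rank_group column: 8 groups of 3 rows each.
rank_group = np.repeat(np.arange(1, 9), 3)
n = len(rank_group)

# The call from the question passes three sizes; np.split treats them as three
# cut positions and returns four pieces, so unpacking into three names fails.
pieces = np.split(rank_group, [int(.7 * n), int(.2 * n), int(.1 * n)])
print(len(pieces))  # 4

# With cumulative positions (70% and 70% + 20% of the rows) the call works,
# but the first cut falls inside group 6, which then appears in two sets.
cuts = [int(0.7 * n), int(0.9 * n)]
train, validation, test = np.split(rank_group, cuts)
print(cuts)  # [16, 21]
print(np.unique(train), np.unique(validation), np.unique(test))
```

This is why the answers filter on the rank_group values instead of on row positions.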

Approach II: using ad-hoc split

    # train_rate, validation_rate and test_rate hold the ratios (0.7, 0.2, 0.1)
    max_rank_group = user_dataset['rank_group'].max()

    train_number = round(max_rank_group * train_rate)
    validation_number = round((max_rank_group - train_number) * validation_rate)
    test_number = round((max_rank_group - validation_number) * test_rate)

    print('train_number ', train_number)
    print('validation_number ', validation_number)
    print('test_number ', test_number)
    print(' ')

    # Note: round() already returns whole numbers, so these fractional parts are always 0
    train_number_frac = train_number % 1
    validation_number_frac = validation_number % 1
    test_number_frac = test_number % 1

    current_train_rank_list = []
    if train_number_frac >= 0.5:
        current_train_rank_list = range(1, train_number + 1)
    else:
        current_train_rank_list = range(1, train_number)

    current_validation_rank_list = []
    if validation_number_frac >= 0.5 and (train_number + validation_number + 2) < max_rank_group:
        current_validation_rank_list = range(train_number, train_number + validation_number + 2)
    else:
        current_validation_rank_list = range(train_number, train_number + validation_number + 1)

    current_test_rank_list = []
    if test_number_frac >= 0.5 and (train_number + validation_number + test_number + 2) < max_rank_group:
        current_test_rank_list = range(train_number + validation_number, train_number + validation_number + test_number + 2)
    else:
        current_test_rank_list = range(train_number + validation_number, train_number + validation_number + test_number + 1)

    #current_validation_rank_list = range(train_number, train_number + validation_number)
    #current_test_rank_list = range(train_number + validation_number, train_number + validation_number + test_number)

    print('current_train_rank_list ', current_train_rank_list)
    print('current_validation_rank_list ', current_validation_rank_list)
    print('current_test_rank_list ', current_test_rank_list)
    print(' ')
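The group counts the question expects (5, 2 and 1 of the 8 groups) can be worked out directly. Rounding each share on its own overshoots, because 70%, 20% and 10% of 8 round up to 6, 2 and 1. The sketch below uses one possible rule, which is an assumption rather than the only reasonable choice: round the validation and test shares and give train whatever remains, so the counts always add up to the number of groups.

```python
n_groups = 8  # number of distinct rank_group values in the question's data

# Rounding every share independently overshoots: 6 + 2 + 1 = 9 groups for 8 available.
naive = [round(n_groups * r) for r in (0.7, 0.2, 0.1)]

# Assumed rule: round the two smaller shares and give train the remainder.
n_test = round(n_groups * 0.1)        # round(0.8) -> 1
n_val = round(n_groups * 0.2)         # round(1.6) -> 2
n_train = n_groups - n_val - n_test   # 8 - 2 - 1 -> 5

# Groups are numbered 1..n_groups, so each set is a consecutive range of groups.
train_groups = list(range(1, n_train + 1))
val_groups = list(range(n_train + 1, n_train + n_val + 1))
test_groups = list(range(n_train + n_val + 1, n_groups + 1))

print(naive, sum(naive))
print(train_groups, val_groups, test_groups)
```

The resulting lists can then be passed to `isin` on the rank_group column to select the rows of each set.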

Any help would be appreciated.

Asked Feb 4 at 16:06 by Carlo Allocca; edited Feb 4 at 16:22.

2 Answers


Just take subsets using boolean conditions on the rank_group column:

import pandas as pd
df = pd.DataFrame({'rank_group':[1,1,2,2,2,2,3,3,3,4,4,4,5,5,6,6,6,7,7,8,8,8]})

train, validation, test = df[df['rank_group'] <6], df[df['rank_group'].isin([6,7])], df[df['rank_group'] >7]

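Or, generalising to 70%, 20% and 10% (this assumes the groups are numbered consecutively from 1, so that the maximum equals the number of groups):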

max_rank_group = df['rank_group'].max()

train_ratio, val_ratio, test_ratio = 0.7, 0.2, 0.1  # test takes whatever groups remain

train_threshold = round(max_rank_group * train_ratio)  # round(8 * 0.7) = 6
val_threshold = round(max_rank_group * val_ratio)      # round(8 * 0.2) = 2

train = df[df['rank_group'] < train_threshold]  # Below the train threshold
validation = df[(df['rank_group'] >= train_threshold) & (df['rank_group'] < train_threshold + val_threshold)]  # Between the train and validation thresholds
test = df[df['rank_group'] >= train_threshold + val_threshold]  # At or above the combined train + validation threshold
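The threshold version treats the largest rank_group value as the number of groups. A sketch that works from the sorted unique labels instead (so the groups need not start at 1 or be contiguous) could look like the following. `split_by_group` is a hypothetical helper, and the rounding rule (round validation and test, give train the rest) is an assumption chosen because it reproduces the 5/2/1 split expected in the question.

```python
import pandas as pd


def split_by_group(df, col="rank_group", val_ratio=0.2, test_ratio=0.1):
    """Hypothetical helper: split df into train/validation/test by whole groups.

    Groups are taken in sorted order of their labels; the validation and test
    group counts are rounded and train receives the remaining groups.
    """
    groups = sorted(df[col].unique().tolist())
    n = len(groups)
    n_test = round(n * test_ratio)
    n_val = round(n * val_ratio)
    n_train = n - n_val - n_test

    train_groups = groups[:n_train]
    val_groups = groups[n_train:n_train + n_val]
    test_groups = groups[n_train + n_val:]

    return (
        df[df[col].isin(train_groups)],
        df[df[col].isin(val_groups)],
        df[df[col].isin(test_groups)],
    )


# Toy data from the answer above.
df = pd.DataFrame({"rank_group": [1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 6, 6, 6, 7, 7, 8, 8, 8]})
train, validation, test = split_by_group(df)
print(sorted(train["rank_group"].unique()))
print(sorted(validation["rank_group"].unique()))
print(sorted(test["rank_group"].unique()))
```

With the 8 groups of the toy data this selects groups 1 to 5 for train, 6 and 7 for validation and 8 for test, and every row ends up in exactly one set.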

You can also do this directly with pandas:

import pandas as pd

train_set = df[df['rank_group'].isin([1.0, 2.0, 3.0, 4.0, 5.0])]
validation_set = df[df['rank_group'].isin([6.0, 7.0])]
test_set = df[df['rank_group'] == 8.0]

If you want a random split instead, you can use sklearn.model_selection.train_test_split and then split its held-out part again to obtain the validation set.
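Applied to rows, `train_test_split` can place rows of the same rank_group in different sets. If a random split is wanted while keeping each group whole, one option is to shuffle the group labels rather than the rows. The sketch below swaps in NumPy's `default_rng` instead of scikit-learn; `random_group_split` and its parameters are hypothetical, and the rounding rule mirrors the one used above as an assumption. scikit-learn's `GroupShuffleSplit` is another route to group-aware random splits.

```python
import numpy as np
import pandas as pd


def random_group_split(df, col="rank_group", val_ratio=0.2, test_ratio=0.1, seed=0):
    """Hypothetical helper: shuffle the group labels, then split whole groups."""
    rng = np.random.default_rng(seed)
    groups = rng.permutation(df[col].unique())  # shuffled copy of the labels
    n = len(groups)
    n_test = round(n * test_ratio)
    n_val = round(n * val_ratio)

    test_groups = groups[:n_test]
    val_groups = groups[n_test:n_test + n_val]
    train_groups = groups[n_test + n_val:]  # train receives the remaining groups

    return (
        df[df[col].isin(train_groups)],
        df[df[col].isin(val_groups)],
        df[df[col].isin(test_groups)],
    )


df = pd.DataFrame({"rank_group": [1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 6, 6, 6, 7, 7, 8, 8, 8]})
train, validation, test = random_group_split(df, seed=42)

# Each group lands in exactly one set, so the pairwise intersections are empty.
train_g, val_g, test_g = (set(part["rank_group"]) for part in (train, validation, test))
print(train_g & val_g, train_g & test_g, val_g & test_g)
```

Fixing `seed` makes the shuffle reproducible; which groups land in each set depends on the seed, but the group counts (5, 2 and 1 for 8 groups) do not.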
