diff --git a/utils_ml.py b/utils_ml.py index 1955ef9..825f9c0 100644 --- a/utils_ml.py +++ b/utils_ml.py @@ -27,7 +27,7 @@ import pickle import warnings -def load_train_test(type_of_activity): +def load_train_test(type_of_activity, type_of_model): BUCKET = f"projet-bdc2324-team1/Generalization_v2/{type_of_activity}" File_path_train = BUCKET + "/Train_set.csv" File_path_test = BUCKET + "/Test_set.csv" @@ -39,7 +39,12 @@ def load_train_test(type_of_activity): with fs.open(File_path_test, mode="rb") as file_in: dataset_test = pd.read_csv(file_in, sep=",") # dataset_test['y_has_purchased'] = dataset_test['y_has_purchased'].fillna(0) - + + if type_of_model=='premium': + dataset_train['company'] = dataset_train['customer_id'].apply(lambda x: x.split('_')[0]) + dataset_test['company'] = dataset_test['customer_id'].apply(lambda x: x.split('_')[0]) + dataset_train = dataset_train[dataset_train['company'].isin(['1', '3', '4', '5', '6', '7', '8', '10', '11', '13'])] + dataset_test = dataset_test[dataset_test['company'].isin(['1', '3', '4', '5', '6', '7', '8', '10', '11', '13'])] return dataset_train, dataset_test