fix premium

This commit is contained in:
Alexis REVELLE 2024-03-28 11:19:05 +00:00
parent eb87cc6998
commit 02a4ea20dd

View File

@@ -27,7 +27,7 @@ import pickle
import warnings
def load_train_test(type_of_activity):
def load_train_test(type_of_activity, type_of_model):
BUCKET = f"projet-bdc2324-team1/Generalization_v2/{type_of_activity}"
File_path_train = BUCKET + "/Train_set.csv"
File_path_test = BUCKET + "/Test_set.csv"
@@ -39,7 +39,12 @@ def load_train_test(type_of_activity):
with fs.open(File_path_test, mode="rb") as file_in:
dataset_test = pd.read_csv(file_in, sep=",")
# dataset_test['y_has_purchased'] = dataset_test['y_has_purchased'].fillna(0)
if type_of_model=='premium':
dataset_train['company'] = dataset_train['customer_id'].apply(lambda x: x.split('_')[0])
dataset_test['company'] = dataset_test['customer_id'].apply(lambda x: x.split('_')[0])
dataset_train = dataset_train[dataset_train['company'].isin(['1', '3', '4', '5', '6', '7', '8', '10', '11', '13'])]
dataset_test = dataset_test[dataset_test['company'].isin(['1', '3', '4', '5', '6', '7', '8', '10', '11', '13'])]
return dataset_train, dataset_test