fix premium
This commit is contained in:
parent
eb87cc6998
commit
02a4ea20dd
|
@ -27,7 +27,7 @@ import pickle
|
|||
import warnings
|
||||
|
||||
|
||||
def load_train_test(type_of_activity):
|
||||
def load_train_test(type_of_activity, type_of_model):
|
||||
BUCKET = f"projet-bdc2324-team1/Generalization_v2/{type_of_activity}"
|
||||
File_path_train = BUCKET + "/Train_set.csv"
|
||||
File_path_test = BUCKET + "/Test_set.csv"
|
||||
|
@ -40,6 +40,11 @@ def load_train_test(type_of_activity):
|
|||
dataset_test = pd.read_csv(file_in, sep=",")
|
||||
# dataset_test['y_has_purchased'] = dataset_test['y_has_purchased'].fillna(0)
|
||||
|
||||
if type_of_model=='premium':
|
||||
dataset_train['company'] = dataset_train['customer_id'].apply(lambda x: x.split('_')[0])
|
||||
dataset_test['company'] = dataset_test['customer_id'].apply(lambda x: x.split('_')[0])
|
||||
dataset_train = dataset_train[dataset_train['company'].isin(['1', '3', '4', '5', '6', '7', '8', '10', '11', '13'])]
|
||||
dataset_test = dataset_test[dataset_test['company'].isin(['1', '3', '4', '5', '6', '7', '8', '10', '11', '13'])]
|
||||
return dataset_train, dataset_test
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user