fix features

This commit is contained in:
Alexis REVELLE 2024-03-28 07:56:36 +00:00
parent adc62dd056
commit ebdbacbe34

View File

@ -28,7 +28,7 @@ import warnings
def load_train_test(type_of_activity): def load_train_test(type_of_activity):
BUCKET = f"projet-bdc2324-team1/Generalization/{type_of_activity}" BUCKET = f"projet-bdc2324-team1/Generalization_v2/{type_of_activity}"
File_path_train = BUCKET + "/Train_set.csv" File_path_train = BUCKET + "/Train_set.csv"
File_path_test = BUCKET + "/Test_set.csv" File_path_test = BUCKET + "/Test_set.csv"
@ -83,9 +83,7 @@ def compute_recall_companies(dataset_test, y_pred, type_of_activity, model):
def features_target_split(dataset_train, dataset_test): def features_target_split(dataset_train, dataset_test):
features_l = ['nb_tickets', 'nb_purchases', 'total_amount', 'nb_suppliers', 'vente_internet_max', 'purchase_date_min', 'purchase_date_max', features_l = [']
'time_between_purchase', 'nb_tickets_internet', 'is_email_true', 'opt_in', #'is_partner',
'gender_female', 'gender_male', 'gender_other', 'nb_campaigns', 'nb_campaigns_opened', 'country_fr']
X_train = dataset_train[features_l] X_train = dataset_train[features_l]
y_train = dataset_train[['y_has_purchased']] y_train = dataset_train[['y_has_purchased']]
@ -94,30 +92,29 @@ def features_target_split(dataset_train, dataset_test):
return X_train, X_test, y_train, y_test return X_train, X_test, y_train, y_test
def preprocess(type_of_model): def preprocess(type_of_model, type_of_activity):
numeric_features = ['nb_campaigns', 'taux_ouverture_mail', 'prop_purchases_internet', 'nb_tickets', 'nb_purchases', 'total_amount', 'nb_suppliers',
'purchases_10_2021','purchases_10_2022', 'purchases_11_2021', 'purchases_12_2021','purchases_1_2022', 'purchases_2_2022', 'purchases_3_2022',
'purchases_4_2022', 'purchases_5_2021', 'purchases_5_2022', 'purchases_6_2021', 'purchases_6_2022', 'purchases_7_2021', 'purchases_7_2022', 'purchases_8_2021',
'purchases_8_2022','purchases_9_2021', 'purchases_9_2022', 'purchase_date_min', 'purchase_date_max', 'nb_targets']
binary_features = ['gender_female', 'gender_male', 'country_fr', 'achat_internet', 'categorie_age_0_10', 'categorie_age_10_20', 'categorie_age_20_30','categorie_age_30_40',
'categorie_age_40_50', 'categorie_age_50_60', 'categorie_age_60_70', 'categorie_age_70_80', 'categorie_age_plus_80','categorie_age_inconnue',
'country_fr', 'is_profession_known', 'is_zipcode_known', 'opt_in']
if type_of_model=='premium': if type_of_model=='premium':
numeric_features = ['nb_tickets', 'nb_purchases', 'total_amount', 'nb_suppliers', 'vente_internet_max', if type_of_activity=='musique':
'purchase_date_min', 'purchase_date_max', 'time_between_purchase', 'nb_tickets_internet', binary_features.extend(['target_optin', 'target_newsletter'])
'nb_campaigns', 'nb_campaigns_opened'] elif type_of_activity=='sport':
binary_features.extend(['target_jeune', 'target_entreprise', 'target_abonne'])
else:
binary_features.extend([ 'target_scolaire', 'target_entreprise', 'target_famille', 'target_newsletter'])
binary_features = ['gender_female', 'gender_male', 'gender_other', 'country_fr']
categorical_features = ['opt_in']
else:
numeric_features = ['nb_tickets', 'nb_purchases', 'total_amount', 'nb_suppliers', 'vente_internet_max',
'purchase_date_min', 'purchase_date_max', 'time_between_purchase', 'nb_tickets_internet',
'nb_campaigns', 'nb_campaigns_opened']
binary_features = ['gender_female', 'gender_male', 'gender_other', 'country_fr']
categorical_features = ['opt_in']
numeric_transformer = Pipeline(steps=[ numeric_transformer = Pipeline(steps=[
("scaler", StandardScaler()) ("scaler", StandardScaler())
]) ])
categorical_features = ['opt_in']
categorical_transformer = Pipeline(steps=[
("onehot", OneHotEncoder(handle_unknown='ignore', sparse_output=False))
])
binary_transformer = Pipeline(steps=[ binary_transformer = Pipeline(steps=[
("imputer", SimpleImputer(strategy="most_frequent")), ("imputer", SimpleImputer(strategy="most_frequent")),
@ -125,7 +122,6 @@ def preprocess(type_of_model):
preproc = ColumnTransformer( preproc = ColumnTransformer(
transformers=[ transformers=[
("num", numeric_transformer, numeric_features), ("num", numeric_transformer, numeric_features),
("cat", categorical_transformer, categorical_features),
("bin", binary_transformer, binary_features) ("bin", binary_transformer, binary_features)
] ]
) )