diff --git a/0_5_Machine_Learning.py b/0_5_Machine_Learning.py index b893aed..acbf790 100644 --- a/0_5_Machine_Learning.py +++ b/0_5_Machine_Learning.py @@ -42,7 +42,7 @@ type_of_model = input('Choisissez le type de model : basique ? premium ?') S3_ENDPOINT_URL = "https://" + os.environ["AWS_S3_ENDPOINT"] fs = s3fs.S3FileSystem(client_kwargs={'endpoint_url': S3_ENDPOINT_URL}) -dataset_train, dataset_test = load_train_test(type_of_activity) +dataset_train, dataset_test = load_train_test(type_of_activity, type_of_model) X_train, X_test, y_train, y_test = features_target_split(dataset_train, dataset_test)