diff --git a/preprocessing.py b/preprocessing.py
index d73c607f9b16f78f639b25c726ba41e894dab78c..bdc0c51deccbecdcc3bbb911675e9572981cefcb 100644
--- a/preprocessing.py
+++ b/preprocessing.py
@@ -13,7 +13,9 @@ Functions:
 
     preprocess_data(df: pd.DataFrame, target_column: str,
         sampling_strategy: str, test_size: float, 
-        random_state: int) 
+        random_state: int,
+        oversample_on_test_data: bool
+        ) 
         ->
         Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
 
@@ -37,17 +39,27 @@ def load_data(filepath="OneHotEncoder_2.csv"):
     return pd.read_csv(filepath)
 
 # Function to preprocess data
-def preprocess_data(df, target_column='Mortality_All', sampling_strategy='minority', test_size=0.30, random_state=0):
+def preprocess_data(df, target_column='Mortality_All', sampling_strategy='minority', test_size=0.30, random_state=0, oversample_on_test_data=False):
     # Define the predicted target
     X = df.drop(target_column, axis=1)
     y = df[target_column]
 
     # Define the oversampling strategy for balancing between the two classes
     oversample = RandomOverSampler(sampling_strategy=sampling_strategy)
-    X_ROS, y_ROS = oversample.fit_resample(X, y)
 
-    # Split our data into training and testing sets
-    X_train, X_test, y_train, y_test = train_test_split(X_ROS, y_ROS, test_size=test_size, random_state=random_state)
+    if oversample_on_test_data:
+        # Oversample the whole data set before splitting, so the test set is drawn from resampled data
+        X_ROS, y_ROS = oversample.fit_resample(X, y)
+
+        # Split our data into training and testing sets
+        X_train, X_test, y_train, y_test = train_test_split(X_ROS, y_ROS, test_size=test_size, random_state=random_state)
+
+    else:
+        # Split the original (not yet resampled) data into training and testing sets first
+        X_train_imb, X_test, y_train_imb, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
+
+        # Oversample training set only
+        X_train, y_train = oversample.fit_resample(X_train_imb, y_train_imb)
 
     # Use StandardScaler for our data
     scaler = StandardScaler()