from typing import Dict, Any
from functools import partial
import argparse
import os
import time
import numpy as np
import pickle as pkl
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import MinMaxScaler
from qualitylib.tools import import_python_file, read_dataset
from qualitylib.feature_extractor import get_fex
from qualitylib.runner import Runner
from qualitylib.cross_validate import random_cross_validation
from feature_extractors.ssim_fex import SsimFeatureExtractor # Import feature extractor(s) to make visible to get_fex
# Seed NumPy's global RNG once, at import time, so NumPy-based randomness in this
# process starts from the same state on every run. Presumably this is what keeps
# read_dataset(shuffle=True) and the random CV splits reproducible — TODO confirm
# (worker processes spawned elsewhere may or may not inherit this state).
np.random.seed(0)
class ScaledSVR:
    """Support vector regressor that rescales every feature to [-1, 1] first.

    All constructor arguments are forwarded unchanged to ``sklearn.svm.SVR``.
    The scaler is fit only on the data passed to :meth:`fit`; :meth:`predict`
    reuses those learned ranges to transform its input.
    """

    def __init__(self, *svr_args, **svr_kwargs) -> None:
        # The feature range is fixed; only the SVR itself is configurable.
        self.scaler = MinMaxScaler(feature_range=(-1, 1))
        self.reg = SVR(*svr_args, **svr_kwargs)

    def fit(self, X: np.ndarray, y: np.ndarray) -> 'ScaledSVR':
        """Fit the scaler on ``X``, then fit the SVR on the scaled features.

        Returns:
            ``self``, following the scikit-learn estimator convention so that
            calls such as ``model.fit(X, y).predict(X_test)`` can be chained.
        """
        X_trans = self.scaler.fit_transform(X)
        self.reg.fit(X_trans, y)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Scale ``X`` with the ranges learned in :meth:`fit` and return SVR predictions."""
        return self.reg.predict(self.scaler.transform(X))
def print_agg_stats(stats: Dict[str, Any]) -> None:
    """Print summary statistics for the best hyperparameter setting as CSV rows.

    Args:
        stats: Maps each hyperparameter key to a list with one dict per
            cross-validation split; each dict maps metric name -> value and
            must contain an ``'SROCC'`` entry.

    The setting with the highest median SROCC is chosen. The first one wins
    ties, and settings whose median is NaN rank last. For each metric of that
    setting, one row is printed with its median, an approximate 95% confidence
    interval for the median, and its standard deviation.

    Raises:
        ValueError: if ``stats`` is empty.
    """
    if not stats:
        raise ValueError('stats must contain at least one hyperparameter setting')

    def _median_srocc(param_key: str) -> float:
        medval = np.median([stat['SROCC'] for stat in stats[param_key]])
        # NaN never compares greater than anything, so rank it below every real median.
        return -np.inf if np.isnan(medval) else float(medval)

    # max() returns the first maximal key, matching a strict '>' scan in insertion order.
    max_param_key = max(stats, key=_median_srocc)
    best_stats = stats[max_param_key]

    # Distribution-free CI for the median: order statistics at percentiles
    # 50 +/- 1.96*50/sqrt(n) (normal approximation to Binomial(n, 0.5)).
    # Clip to [0, 100] because np.percentile rejects values outside that range,
    # which happens whenever n <= 3.
    num_samples = len(best_stats)
    half_width = 1.96*0.5/np.sqrt(num_samples)
    lo_ci = max((0.5 - half_width)*100, 0.0)
    hi_ci = min((0.5 + half_width)*100, 100.0)

    print(f'Optimal param: {max_param_key}')
    print('Stat,Median,LoCI,HiCI,Std')  # Not using spaces makes parsing text output as csv easier
    for stat_key in best_stats[0]:
        key_stats = np.array([stat[stat_key] for stat in best_stats])
        print(f'{stat_key},{np.median(key_stats):.4f},{np.percentile(key_stats, lo_ci):.4f},{np.percentile(key_stats, hi_ci):.4f},{np.std(key_stats):.4f}')
def dict_to_str(d: Dict[Any, Any]) -> str:
    """Flatten a mapping into ``key_value_key_value...`` in iteration order.

    Keys and values are converted with ``str``. For example,
    ``{'kernel': 'linear', 'C': 0.1}`` becomes ``'kernel_linear_C_0.1'``;
    an empty mapping yields ``''``.
    """
    return '_'.join(str(part) for key in d for part in (key, d[key]))
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the grid-search cross-validation script.

    Returns:
        An ``ArgumentParser``. Only ``--out_file`` is required; ``--regressor``
        is restricted to the regressor names that ``main`` understands.
    """
    parser = argparse.ArgumentParser(description='Conduct gridsearch crossvalidation')
    parser.add_argument('--dataset', help='Path to dataset file for which to extract features', type=str)
    parser.add_argument('--fex_name', help='Name of feature extractor', type=str)
    parser.add_argument('--fex_version', help='Version of feature extractor', type=str, default=None)
    parser.add_argument('--feat_names_file', help='Path to file containing feature sets', type=str, default=None)
    # Reject unknown regressors at parse time instead of after the dataset is loaded.
    parser.add_argument('--regressor', help='Regressor to use', type=str, default='RandomForest',
                        choices=['RandomForest', 'LinearSVR', 'GaussianSVR'])
    parser.add_argument('--splits', help='Number of random train/test splits', type=int, default=100)
    parser.add_argument('--processes', help='Number of parallel processes', type=int, default=100)
    parser.add_argument('--out_file', help='Path to output pickle file', type=str, required=True)
    return parser
def _get_model_grid(regressor: str, processes: int) -> tuple:
    """Return ``(ModelClass, model_params)`` for the named regressor.

    ``model_params`` is the list of keyword-argument dicts to grid-search over.
    Each dict is bound to ``ModelClass`` with ``functools.partial`` and turned
    into a result key with ``dict_to_str``, so key order inside each dict is
    part of the output format and is kept stable here.

    Raises:
        ValueError: if ``regressor`` is not a supported name.
    """
    if regressor == 'RandomForest':
        # NOTE(review): appears to split a fixed budget of 100 jobs between the
        # parallel CV processes and each forest's trees — confirm intent.
        n_jobs = max(100//processes, 1)
        return RandomForestRegressor, [{'max_features': max_feat, 'n_estimators': n_est, 'n_jobs': n_jobs} for max_feat in [0.25, 0.5, 0.75, 1.0] for n_est in [25, 50, 100, 200]]
    if regressor == 'LinearSVR':
        return ScaledSVR, [{'kernel': 'linear', 'C': c} for c in [1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3]]
    if regressor == 'GaussianSVR':
        return ScaledSVR, [{'kernel': 'rbf', 'C': c, 'gamma': gam} for c in [1, 1e1, 1e2, 1e3] for gam in [1e-6, 1e-4, 1e-2, 1]]
    raise ValueError('Invalid regressor')


def main() -> None:
    """Grid-search regressor hyperparameters with repeated cross-validation.

    For each feature set (from ``--feat_names_file``, or a single ``'all'``
    set), every hyperparameter setting is evaluated over ``--splits``
    cross-validation splits with a 20% test fraction. The best setting's
    summary is printed, and all per-split metrics are pickled to
    ``--out_file`` as ``{feature_set: {param_str: split_stats}}``.
    If ``--out_file`` already exists, a notice is printed and nothing is run.
    """
    args = get_parser().parse_args()
    if os.path.isfile(args.out_file):
        # Existing results are treated as final; delete the file to recompute.
        print(f'Output file {args.out_file} already exists. Skipping.')
        return

    # Resolve the model grid before any expensive work so a bad --regressor fails fast.
    ModelClass, model_params = _get_model_grid(args.regressor, args.processes)

    # Results are only written at the very end, so create the output directory
    # now rather than losing the whole run to a missing directory.
    out_dir = os.path.dirname(args.out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    dataset = import_python_file(args.dataset)
    assets = read_dataset(dataset, shuffle=True)
    FexClass = get_fex(args.fex_name, args.fex_version)
    runner = Runner(FexClass, processes=args.processes, use_cache=True)  # Reads from stored results if available, else stores results.

    if args.feat_names_file is not None:
        feat_names_dict = import_python_file(args.feat_names_file).feat_names_dict
    else:
        # Presumably feat_names=None means "no feature filtering" — TODO confirm in Runner.
        feat_names_dict = {'all': None}

    res_dict = {}
    for key, feat_names in feat_names_dict.items():
        # Extract features and return only specified features for cross-validation.
        results = runner(assets, return_results=True, feat_names=np.array(feat_names) if feat_names is not None else None)
        start_time = time.time()
        temp_res_dict = {}
        for model_param_dict in model_params:
            agg_stats = random_cross_validation(partial(ModelClass, **model_param_dict), results, splits=args.splits, test_fraction=0.2, processes=args.processes)
            temp_res_dict[dict_to_str(model_param_dict)] = agg_stats['stats']  # Metrics computed from each split.
            print(f'Tested params: {model_param_dict}. Time elapsed {((time.time() - start_time)/60):.2f} minutes.')
        print(f'Results - {key}')
        print_agg_stats(temp_res_dict)
        res_dict[key] = temp_res_dict

    with open(args.out_file, 'wb') as out_file:
        pkl.dump(res_dict, out_file)
# Run the grid search only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()