#!/usr/bin/python
# -*- coding: latin-1 -*-
"""Leave-one-out cross validation driver for an external classifier script.

The training file is tab-separated and starts with three special lines:

1. the header (attribute names),
2. the attribute types (continuous, discrete, ignore),
3. the class line, where the target attribute is marked with ``class``.

Every remaining line is one example.  For each example the script writes a
training file containing all the other examples and a test file containing
only that example, runs the classifier on the pair and reads back the
classifier's tab-separated output.
"""

import argparse
import csv
import os
import sys
import tempfile


def parse_args(argv=None):
    """Parse the command line and return the options as a dict.

    ``argv`` defaults to ``sys.argv[1:]`` when it is None.
    """
    parser = argparse.ArgumentParser(description='Cross validation script.')
    parser.add_argument('--training', required=True,
                        help='File in CSV format containing training examples.')
    parser.add_argument('--script', required=True,
                        help='The script that performs classification.')
    parser.add_argument('--options', required=False,
                        help='Other options to pass to the classifier script.')
    return vars(parser.parse_args(argv))


def find_target_class(lines):
    """Return the name of the attribute marked ``class`` in the class line.

    ``lines`` is an iterable of the raw lines of the training file.  Returns
    an empty string when no attribute is marked, and raises ValueError when
    the attribute-type line or the class line is missing.
    """
    reader = csv.DictReader(lines, delimiter='\t')
    try:
        # The header is consumed as field names; the next two rows are the
        # attribute types and the class markers.
        next(reader)
        class_row = next(reader)
    except StopIteration:
        raise ValueError('expected a header, an attribute type line and a class line')
    target = ''
    # Walk the header columns only, so surplus fields on the attribute-type
    # line cannot trigger a lookup for a column the class line lacks.
    for name in reader.fieldnames:
        if class_row.get(name) == 'class':
            target = name
    return target


def split_training_lines(lines):
    """Split raw training lines into (header, attribute_line, class_line, rows).

    The input list is left untouched.  Raises ValueError when the three
    leading description lines are not all present.
    """
    if len(lines) < 3:
        raise ValueError('expected a header, an attribute type line and a class line')
    header, attribute_line, class_line = lines[:3]
    return header, attribute_line, class_line, lines[3:]


def write_fold(training_file, test_file, header, attribute_line, class_line,
               rows, held_out):
    """Write one fold: row ``held_out`` is the test sample, the rest is training.

    The training file gets all three description lines; the test file only
    gets the header.
    """
    with open(training_file, 'w') as training, open(test_file, 'w') as test:
        training.write(header)
        training.write(attribute_line)
        training.write(class_line)
        test.write(header)
        for index, row in enumerate(rows):
            if index == held_out:
                test.write(row)
            else:
                training.write(row)


def build_command(script, training_file, test_file, options, output_file):
    """Build the shell command that runs the classifier on one fold.

    ``options`` may be None or empty (no extra options); otherwise it is
    inserted as its own space-separated chunk before the output redirection.
    """
    parts = [script, '--training', training_file, '--sample', test_file]
    if options:
        parts.append(options)
    return ' '.join(parts) + ' > ' + output_file


def count_predictions(output_file):
    """Return the number of data rows in the classifier's tab-separated output."""
    with open(output_file) as f:
        return sum(1 for _ in csv.DictReader(f, delimiter='\t'))


def make_temp_file():
    """Create an empty temporary file and return its path."""
    # mkstemp creates the file atomically, unlike the deprecated mktemp.
    handle, path = tempfile.mkstemp()
    os.close(handle)
    return path


def main(argv=None):
    """Run the classifier once per training example (leave-one-out)."""
    opt = parse_args(argv)

    # LOAD KNOWN DATA
    with open(opt['training']) as f:
        lines = f.readlines()
    try:
        # Not used yet: needed by the metrics TODO below.
        target_class = find_target_class(lines)
        header, attribute_line, class_line, rows = split_training_lines(lines)
    except ValueError as err:
        raise SystemExit('%s: %s' % (opt['training'], err))

    tests = 0
    # TODO: also track errors and the confusion-matrix counters
    # (TP, FP, PP, P, TN, FN, PN, N) once predictions are compared
    # against the target class.

    # ITERATE OVER DATA TO GENERATE TRAINING AND TEST SETS, AND RUN CLASSIFIER
    temp_paths = []
    try:
        for _ in range(3):
            temp_paths.append(make_temp_file())
        training_file, test_file, output_file = temp_paths

        for held_out in range(len(rows)):
            write_fold(training_file, test_file, header, attribute_line,
                       class_line, rows, held_out)

            # RUN CLASSIFIER
            # The command goes through the shell because its output is
            # captured with a redirection.
            cmd = build_command(opt['script'], training_file, test_file,
                                opt['options'], output_file)
            status = os.system(cmd)
            if status != 0:
                sys.stderr.write('Warning: classifier exited with status %s: %s\n'
                                 % (status, cmd))

            tests += count_predictions(output_file)
            # TODO: update PP, FP, TP, ... by comparing the target class
            # column (target_class) with the 'prediction' column of each row.
    finally:
        # Remove the temporary files even when a fold fails.
        for path in temp_paths:
            if os.path.exists(path):
                os.remove(path)

    # TODO: print tests, errors and error rate;
    # compute and print sensitivity, specificity, precision & accuracy.


if __name__ == '__main__':
    main()