Регистрация | Вход
import java.util.*;public class Knn { public static final String PATH_TO_DATA_FILE = "coupious.data"; public static final int NUM_ATTRS = 9; public static final int K = 262; public static final int CATEGORY_INDEX = 0; public static final int DISTANCE_INDEX = 1; public static final int EXPIRATION_INDEX = 2; public static final int HANDSET_INDEX = 3; public static final int OFFER_INDEX = 4; public static final int WSACTION_INDEX = 5; public static final int NUM_RUNS = 1000; public static double averageDistance = 0; public static void main(String[] args) { ArrayList instances = null; ArrayList distances = null; ArrayList neighbors = null; WSAction.Action classification = null; Instance classificationInstance = null; FileReader reader = null; int numRuns = 0, truePositives = 0, falsePositives = 0, falseNegatives = 0, trueNegatives = 0; double precision = 0, recall = 0, fMeasure = 0; falsePositives = 1; reader = new FileReader(PATH_TO_DATA_FILE); instances = reader.buildInstances(); do { classificationInstance = extractIndividualInstance(instances); distances = calculateDistances(instances, classificationInstance); neighbors = getNearestNeighbors(distances); classification = determineMajority(neighbors); System.out.println("Gathering " + K + " nearest neighbors to:"); printClassificationInstance(classificationInstance); printNeighbors(neighbors); System.out.println("\nExpected situation result for instance: " + classification.toString()); if(classification.toString().equals(((WSAction)classificationInstance.getAttributes().get(WSACTION_INDEX)).getAction().toString())) { truePositives++; } else { falseNegatives++; } numRuns++; instances.add(classificationInstance); } while(numRuns < NUM_RUNS); precision = ((double)(truePositives / (double)(truePositives + falsePositives))); recall = ((double)(truePositives / (double)(truePositives + falseNegatives))); fMeasure = ((double)(precision * recall) / (double)(precision + recall)); System.out.println("Precision: " + precision); System.out.println("Recall: " + recall); System.out.println("F-Measure: " + fMeasure); System.out.println("Average distance: " + (double)(averageDistance / (double)(NUM_RUNS * K))); } public static Instance extractIndividualInstance(ArrayList instances) { Random generator = new Random(new Date().getTime()); int random = generator.nextInt(instances.size() - 1); Instance singleInstance = instances.get(random); instances.remove(random); return singleInstance; } public static void printClassificationInstance(Instance classificationInstance) { for(Feature f : classificationInstance.getAttributes()) { System.out.print(f.getName() + ": "); if(f instanceof Category) { System.out.println(((Category)f).getCategory().toString()); } else if(f instanceof Distance) { System.out.println(((Distance)f).getDistance().toString()); } else if (f instanceof Expiration) { System.out.println(((Expiration)f).getExpiry().toString()); } else if (f instanceof Handset) { System.out.print(((Handset)f).getOs().toString() + ", "); System.out.println(((Handset)f).getDevice().toString()); } else if (f instanceof Offer) { System.out.println(((Offer)f).getOfferType().toString()); } else if (f instanceof WSAction) { System.out.println(((WSAction)f).getAction().toString()); } } } public static void printNeighbors(ArrayList neighbors) { int i = 0; for(Neighbor neighbor : neighbors) { Instance instance = neighbor.getInstance(); System.out.println("\nNeighbor " + (i + 1) + ", distance: " + neighbor.getDistance()); i++; for(Feature f : instance.getAttributes()) { System.out.print(f.getName() + ": "); if(f instanceof Category) { System.out.println(((Category)f).getCategory().toString()); } else if(f instanceof Distance) { System.out.println(((Distance)f).getDistance().toString()); } else if (f instanceof Expiration) { System.out.println(((Expiration)f).getExpiry().toString()); } else if (f instanceof Handset) { System.out.print(((Handset)f).getOs().toString() + ", "); System.out.println(((Handset)f).getDevice().toString()); } else if (f instanceof Offer) { System.out.println(((Offer)f).getOfferType().toString()); } else if (f instanceof WSAction) { System.out.println(((WSAction)f).getAction().toString()); } } } } public static WSAction.Action determineMajority(ArrayList neighbors) { int yea = 0, ney = 0; for(int i = 0; i < neighbors.size(); i++) { Neighbor neighbor = neighbors.get(i); Instance instance = neighbor.getInstance(); if(instance.isRedeemed()) { yea++; } else { ney++; } } if(yea > ney) { return WSAction.Action.Redeem; } else { return WSAction.Action.Hit; } } public static ArrayList getNearestNeighbors(ArrayList distances) { ArrayList neighbors = new ArrayList(); for(int i = 0; i < K; i++) { averageDistance += distances.get(i).getDistance(); neighbors.add(distances.get(i)); } return neighbors; } public static ArrayList calculateDistances(ArrayList instances, Instance singleInstance) { ArrayList distances = new ArrayList(); Neighbor neighbor = null; int distance = 0; for(int i = 0; i < instances.size(); i++) { Instance instance = instances.get(i); distance = 0; neighbor = new Neighbor(); // for each feature, go through and calculate the "distance" for(Feature f : instance.getAttributes()) { if(f instanceof Category) { Category.Categories cat = ((Category) f).getCategory(); Category singleInstanceCat = (Category)singleInstance.getAttributes().get(CATEGORY_INDEX); distance += Math.pow((cat.ordinal() - singleInstanceCat.getCategory().ordinal()), 2); } else if(f instanceof Distance) { Distance.DistanceRange dist = ((Distance) f).getDistance(); Distance singleInstanceDist = (Distance)singleInstance.getAttributes().get(DISTANCE_INDEX); distance += Math.pow((dist.ordinal() - singleInstanceDist.getDistance().ordinal()), 2); } else if (f instanceof Expiration) { Expiration.Expiry exp = ((Expiration) f).getExpiry(); Expiration singleInstanceExp = (Expiration)singleInstance.getAttributes().get(EXPIRATION_INDEX); distance += Math.pow((exp.ordinal() - singleInstanceExp.getExpiry().ordinal()), 2); } else if (f instanceof Handset) { // there are two calculations needed here, one for device, one for OS Handset.Device device = ((Handset) f).getDevice(); Handset singleInstanceDevice = (Handset)singleInstance.getAttributes().get(HANDSET_INDEX); distance += Math.pow((device.ordinal() - singleInstanceDevice.getDevice().ordinal()), 2); Handset.OS os = ((Handset) f).getOs(); Handset singleInstanceOs = (Handset)singleInstance.getAttributes().get(HANDSET_INDEX); distance += Math.pow((os.ordinal() - singleInstanceOs.getOs().ordinal()), 2); } else if (f instanceof Offer) { Offer.OfferType offer = ((Offer) f).getOfferType(); Offer singleInstanceOffer = (Offer)singleInstance.getAttributes().get(OFFER_INDEX); distance += Math.pow((offer.ordinal() - singleInstanceOffer.getOfferType().ordinal()), 2); } else if (f instanceof WSAction) { WSAction.Action action = ((WSAction) f).getAction(); WSAction singleInstanceAction = (WSAction)singleInstance.getAttributes().get(WSACTION_INDEX); distance += Math.pow((action.ordinal() - singleInstanceAction.getAction().ordinal()), 2); } else { System.out.println("Unknown category in distance calculation. Exiting for debug: " + f); System.exit(1); } } neighbor.setDistance(distance); neighbor.setInstance(instance); distances.add(neighbor); } for (int i = 0; i < distances.size(); i++) { for (int j = 0; j < distances.size() - i - 1; j++) { if(distances.get(j).getDistance() > distances.get(j + 1).getDistance()) { Neighbor tempNeighbor = distances.get(j); distances.set(j, distances.get(j + 1)); distances.set(j + 1, tempNeighbor); } } } return distances; }}
import java.io.IOException;import java.util.ArrayList;import java.util.Arrays;import java.util.HashMap;import java.util.Iterator;import java.util.Set;public class knn { public static void main(String[] args){ System.out.println("iris"); knn("classification\\iris_train.txt","classification\\iris_test.txt",1,2); System.out.println(); System.out.println("glass"); knn("classification\\glass_train.txt","classification\\glass_test.txt",1,0); System.out.println(); System.out.println("vowel"); knn("classification\\vowel_train.txt","classification\\vowel_test.txt",3,2); System.out.println(); System.out.println("vehicle"); knn("classification\\vehicle_train.txt","classification\\vehicle_test.txt",3,1); System.out.println(); System.out.println("letter"); knn("classification\\letter_train.txt","classification\\letter_test.txt",3,0); System.out.println(); System.out.println("DNA"); knn("classification\\dna_train.txt","classification\\dna_test.txt",5,2); System.out.println(); } public static void knn(String trainingFile, String testFile, int K, int metricType){ //get the current time final long startTime = System.currentTimeMillis(); // make sure the input arguments are legal if(K <= 0){ System.out.println("K should be larger than 0!"); return; } // metricType should be within [0,2]; if(metricType > 2 || metricType <0){ System.out.println("metricType is not within the range [0,2]. Please try again later"); return; } //TrainingFile and testFile should be the same group String trainGroup = extractGroupName(trainingFile); String testGroup = extractGroupName(testFile); if(!trainGroup.equals(testGroup)){ System.out.println("trainingFile and testFile are illegal!"); return; } try { //read trainingSet and testingSet TrainRecord[] trainingSet = FileManager.readTrainFile(trainingFile); TestRecord[] testingSet = FileManager.readTestFile(testFile); //determine the type of metric according to metricType Metric metric; if(metricType == 0) metric = new CosineSimilarity(); else if(metricType == 1) metric = new L1Distance(); else if (metricType == 2) metric = new EuclideanDistance(); else{ System.out.println("The entered metric_type is wrong!"); return; } //test those TestRecords one by one int numOfTestingRecord = testingSet.length; for(int i = 0; i < numOfTestingRecord; i ++){ TrainRecord[] neighbors = findKNearestNeighbors(trainingSet, testingSet[i], K, metric); int classLabel = classify(neighbors); testingSet[i].predictedLabel = classLabel; //assign the predicted label to TestRecord } //calculate the accuracy int correctPrediction = 0; for(int j = 0; j < numOfTestingRecord; j ++){ if(testingSet[j].predictedLabel == testingSet[j].classLabel) correctPrediction ++; } //Output a file containing predicted labels for TestRecords String predictPath = FileManager.outputFile(testingSet, trainingFile); System.out.println("The prediction file is stored in "+predictPath); System.out.println("The accuracy is "+((double)correctPrediction / numOfTestingRecord)*100+"%"); //print the total execution time final long endTime = System.currentTimeMillis(); System.out.println("Total excution time: "+(endTime - startTime) / (double)1000 +" seconds."); } catch (IOException e) { e.printStackTrace(); } } // Find K nearest neighbors of testRecord within trainingSet static TrainRecord[] findKNearestNeighbors(TrainRecord[] trainingSet, TestRecord testRecord,int K, Metric metric){ int NumOfTrainingSet = trainingSet.length; assert K <= NumOfTrainingSet : "K is lager than the length of trainingSet!"; //Update KNN: take the case when testRecord has multiple neighbors with the same distance into consideration //Solution: Update the size of container holding the neighbors TrainRecord[] neighbors = new TrainRecord[K]; //initialization, put the first K trainRecords into the above arrayList int index; for(index = 0; index < K; index++){ trainingSet[index].distance = metric.getDistance(trainingSet[index], testRecord); neighbors[index] = trainingSet[index]; } //go through the remaining records in the trainingSet to find K nearest neighbors for(index = K; index < NumOfTrainingSet; index ++){ trainingSet[index].distance = metric.getDistance(trainingSet[index], testRecord); //get the index of the neighbor with the largest distance to testRecord int maxIndex = 0; for(int i = 1; i < K; i ++){ if(neighbors[i].distance > neighbors[maxIndex].distance) maxIndex = i; } //add the current trainingSet[index] into neighbors if applicable if(neighbors[maxIndex].distance > trainingSet[index].distance) neighbors[maxIndex] = trainingSet[index]; } return neighbors; } // Get the class label by using neighbors static int classify(TrainRecord[] neighbors){ //construct a HashMap to store <classLabel, weight> HashMap<Integer, Double> map = new HashMap<Integer, Double>(); int num = neighbors.length; for(int index = 0;index < num; index ++){ TrainRecord temp = neighbors[index]; int key = temp.classLabel; //if this classLabel does not exist in the HashMap, put <key, 1/(temp.distance)> into the HashMap if(!map.containsKey(key)) map.put(key, 1 / temp.distance); //else, update the HashMap by adding the weight associating with that key else{ double value = map.get(key); value += 1 / temp.distance; map.put(key, value); } } //Find the most likely label double maxSimilarity = 0; int returnLabel = -1; Set<Integer> labelSet = map.keySet(); Iterator<Integer> it = labelSet.iterator(); //go through the HashMap by using keys //and find the key with the highest weights while(it.hasNext()){ int label = it.next(); double value = map.get(label); if(value > maxSimilarity){ maxSimilarity = value; returnLabel = label; } } return returnLabel; } static String extractGroupName(String filePath){ StringBuilder groupName = new StringBuilder(); for(int i = 15; i < filePath.length(); i ++){ if(filePath.charAt(i) != '_') groupName.append(filePath.charAt(i)); else break; } return groupName.toString(); }}