lucene/solr/contrib/ltr/example/libsvm_formatter.py

127 lines
6.0 KiB
Python
Raw Normal View History

import json
import os
import shlex
from subprocess import call
# A document pair only yields training examples when the two relevance
# values differ by more than this amount.
PAIRWISE_THRESHOLD = 0.1
# Per-feature differences whose absolute value does not exceed this are
# dropped when building pairwise difference vectors.
FEATURE_DIFF_THRESHOLD = 1e-6
class LibSvmFormatter:
    '''Builds a pairwise libSVM training file from per-query document feature
    vectors, and converts the trained linear model back into a Solr LTR
    LinearModel JSON definition.

    The feature-name <-> integer-id mapping created by
    processQueryDocFeatureVector is stored on the instance and is required by
    convertLibSvmModelToLtrModel, so both must be called on the same object.'''

    def processQueryDocFeatureVector(self, docClickInfo, trainingFile):
        '''Write pairwise training examples to trainingFile.

        docClickInfo is a list or generator of tuples
        (query, docId, relevance, source, featureVector), ordered so that all
        rows sharing the same (query, source) are contiguous. Documents within
        one group are compared against each other; documents from different
        groups never are. docId is not used.

        Resets this instance's feature-name/id mapping before processing.'''
        with open(trainingFile, "w") as output:
            self.featureNameToId = {}
            self.featureIdToName = {}
            self.curFeatIndex = 1
            curListOfFv = []
            # The group key is the (query, source) tuple. Concatenating the two
            # strings would let distinct groups collide: ("ab", "c") and
            # ("a", "bc") would both become "abc" and be paired as one query.
            curQueryAndSource = None
            for query, docId, relevance, source, featureVector in docClickInfo:
                if (query, source) != curQueryAndSource:
                    # A new group starts: flush all pairs of the previous one.
                    _writeRankSVMPairs(curListOfFv, output)
                    curListOfFv = []
                    curQueryAndSource = (query, source)
                curListOfFv.append((relevance, self._makeFeaturesMap(featureVector)))
            # Flush the final group.
            _writeRankSVMPairs(curListOfFv, output)

    def _makeFeaturesMap(self, featureVector):
        '''Convert an iterable of "name=value" strings into {featureId: float(value)}.

        Names are replaced by integer ids (see _getFeatureId) because libSVM
        requires integer feature keys. Each string is split on its LAST "=",
        so feature names that themselves contain "=" are kept intact; a string
        without any "=" raises ValueError.'''
        features = {}
        for keyValuePairStr in featureVector:
            featName, featValue = keyValuePairStr.rsplit("=", 1)
            features[self._getFeatureId(featName)] = float(featValue)
        return features

    def _getFeatureId(self, key):
        '''Return the integer id for feature name key. The first time a name is
        seen, it is assigned self.curFeatIndex and the counter is advanced.'''
        if key not in self.featureNameToId:
            self.featureNameToId[key] = self.curFeatIndex
            self.featureIdToName[self.curFeatIndex] = key
            self.curFeatIndex += 1
        return self.featureNameToId[key]

    def convertLibSvmModelToLtrModel(self, libSvmModelLocation, outputFile, modelName, featureStoreName):
        '''Write a Solr LTR LinearModel JSON file built from a trained model file.

        Every non-blank line after the line consisting solely of "w" is read as
        one weight, in feature-id order (1st weight -> id 1, 2nd -> id 2, ...),
        and mapped back to its feature name through the mapping built by
        processQueryDocFeatureVector.

        Raises AttributeError if processQueryDocFeatureVector has not been
        called on this instance, KeyError if the model lists more weights than
        there are known features, and ValueError for a non-numeric weight line.'''
        with open(libSvmModelLocation, 'r') as inFile, open(outputFile, 'w') as convertedOutFile:
            # Strings are encoded with json.dumps so quotes, backslashes and
            # control characters in names are escaped. For ordinary ASCII names
            # the bytes are identical to wrapping the name in plain quotes.
            convertedOutFile.write('{\n\t"class":"org.apache.solr.ltr.model.LinearModel",\n')
            convertedOutFile.write('\t"store": ' + json.dumps(str(featureStoreName)) + ',\n')
            convertedOutFile.write('\t"name": ' + json.dumps(str(modelName)) + ',\n')
            convertedOutFile.write('\t"features": [\n')
            convertedOutFile.write(',\n'.join(
                '\t\t{ "name":' + json.dumps(featKey) + '}' for featKey in self.featureNameToId))
            convertedOutFile.write("\n\t],\n")
            convertedOutFile.write('\t"params": {\n\t\t"weights": {\n')

            weightEntries = []
            startReading = False
            counter = 1
            for line in inFile:
                stripped = line.strip()
                if not startReading:
                    # Header lines are skipped until the "w" marker line.
                    if stripped == 'w':
                        startReading = True
                    continue
                if not stripped:
                    # Tolerate blank (e.g. trailing) lines in the weight section
                    # instead of failing on float('').
                    continue
                newParamVal = float(stripped)
                # NOTE(review): a model trained with a bias term presumably has
                # one extra weight line, which would raise KeyError here -- confirm.
                weightEntries.append(
                    '\t\t\t' + json.dumps(self.featureIdToName[counter]) + ':' + str(newParamVal))
                counter += 1
            convertedOutFile.write(',\n'.join(weightEntries))
            convertedOutFile.write('\n\t\t}\n\t}\n}')
def _writeRankSVMPairs(listOfFeatures, output):
    '''Emit pairwise ranking examples for one group of documents.

    listOfFeatures is a sequence of (relevance, featureMap) entries, where
    relevance is anything float() accepts and featureMap maps integer feature
    ids to values, for example:
        [(4, {1: 0.9, 2: 0.9, 3: 0.1}),
         (3, {1: 0.7, 2: 0.9, 3: 0.2}),
         (1, {1: 0.1, 2: 0.9, 6: 0.1})]

    Each unordered pair whose relevance gap exceeds PAIRWISE_THRESHOLD yields
    two lines: "+1" with (more relevant - less relevant) features, and "-1"
    with the reverse difference. Pairs inside the threshold carry no ordering
    signal and produce nothing.'''
    total = len(listOfFeatures)
    for first in range(total):
        docA = listOfFeatures[first]
        for second in range(first + 1, total):
            docB = listOfFeatures[second]
            fvA, fvB = docA[1], docB[1]
            relA, relB = float(docA[0]), float(docB[0])
            gap = relA - relB
            if gap > PAIRWISE_THRESHOLD:
                preferred, other = fvA, fvB
            elif gap < -PAIRWISE_THRESHOLD:
                preferred, other = fvB, fvA
            else:
                # Roughly equal relevance: no usable preference signal.
                continue
            outputLibSvmLine("+1", subtractFvMap(preferred, other), output)
            outputLibSvmLine("-1", subtractFvMap(other, preferred), output)
def subtractFvMap(fv1, fv2, threshold=None):
    '''Return the sparse difference fv1 - fv2 of two {featureId: value} maps.

    A feature missing from one map counts as 0 there. Entries whose absolute
    difference is not greater than threshold (FEATURE_DIFF_THRESHOLD when
    threshold is None) are omitted, so the result stays sparse and holds no
    near-zero signals. Neither input map is modified; a new dict is returned.'''
    if threshold is None:
        threshold = FEATURE_DIFF_THRESHOLD
    retFv = {}
    for featInd, value in fv1.items():
        if featInd in fv2:
            value = value - fv2[featInd]
        # The threshold applies to fv1-only features as well, so an explicit
        # near-zero value in fv1 is not emitted as a useless "id:0.0" entry.
        if abs(value) > threshold:
            retFv[featInd] = value
    for featInd, value in fv2.items():
        if featInd not in fv1 and abs(value) > threshold:
            retFv[featInd] = -value
    return retFv
def outputLibSvmLine(sign, fvMap, outputFile):
    '''Write one libSVM-format example line: "<sign> id:value id:value ...\\n".

    libSVM/liblinear reject input lines whose feature indices are not in
    ascending order. Dict iteration follows insertion order, not id order, so
    the entries are written sorted by feature id.'''
    parts = [sign]
    parts.extend(str(feat) + ":" + str(fvMap[feat]) for feat in sorted(fvMap))
    outputFile.write(" ".join(parts) + "\n")
def trainLibSvm(libraryLocation, libraryOptions, trainingFileName, trainedModelFileName):
    '''Run the libSVM/liblinear training executable on a training file.

    Invokes: libraryLocation <options...> trainingFileName trainedModelFileName
    directly, without a shell.

    libraryOptions is a command-line option string such as "-c 20". It is
    split shell-style into separate arguments; passed as a single argv
    element, the trainer would see one argument "-c 20" instead of the option
    "-c" followed by its value. An empty or None value adds no arguments.

    Raises Exception if libraryLocation is not an existing file, or if the
    trainer exits with a non-zero status.'''
    if not os.path.isfile(libraryLocation):
        raise Exception("NO LIBRARY FOUND: " + libraryLocation)
    options = shlex.split(libraryOptions) if libraryOptions else []
    returnCode = call([libraryLocation] + options + [trainingFileName, trainedModelFileName])
    if returnCode != 0:
        # Surface training failures here rather than letting them go unnoticed
        # until the model file is consumed.
        raise Exception("TRAINING FAILED (exit code " + str(returnCode) + "): " + libraryLocation)