00001 #ifndef WDK_H
00002 #define WDK_H
00003
00004
00173 #include <stdexcept>
00174 #include <iostream>
00175 #include <sstream>
00176 #include <vector>
00177 #include <list>
00178 #include <map>
00179 #include <algorithm>
00180 #include <numeric>
00181 #include <cassert>
00182 #include <cmath>
00183 #include <cctype>
00184
00185 using namespace std;
00186
00187
/// Floating-point type used for every kernel value, product and norm.
typedef float    RealType;

/// Integer type of the indices stored in the sparse vectors.
typedef long int IndexConcreteType;

/// Integer type of the values stored in the sparse vectors.
typedef int      ValueConcreteType;
00191
00192
/// Base kernel selectable through UserKernelParametersClass::mKernelType.
enum KernelType {DOT_PRODUCT_TYPE=0,HISTOGRAM_INTERSECTION_TYPE};

/// Holds the user-selectable kernel parameters.
///
/// The object stores the base kernel type, the polynomial / exponential
/// modifiers and the three normalization switches, and can be filled from a
/// whitespace-separated option string through Input().
class UserKernelParametersClass
{
public:
    /// Default configuration: histogram-intersection kernel, no exponential
    /// and no polynomial modifier, all normalizations enabled.  The cached
    /// polynomial coefficient is 1, i.e. the value
    /// ComputeInversePolinomialNormalizationCoefficient() yields for a
    /// polynomial degree of 0.
    UserKernelParametersClass():
        mKernelType(HISTOGRAM_INTERSECTION_TYPE),
        mExpGamma(0),
        mPolyDegree(0),
        mPolyOffset(1),
        mPolyCoefficient(1),
        mTotalNormalization(true),
        mProductNormalization(true),
        mAttributeNormalization(true),
        mInversePolinomialNormalizationCoefficient(1)
    {}

    /// Parses a whitespace-separated option string and updates the members.
    ///
    /// Recognized options: -Kd, -Kr, -Ks, -Kg (each followed by a numeric
    /// value), -KDot, -KHistInt, -KNoTotNorm, -KNoProdNorm, -KNoAttrNorm.
    /// Options are applied in order; members already updated stay updated
    /// if a later option is rejected.
    ///
    /// @param aString  option string (must not be null).
    /// @throws std::invalid_argument on an unknown option, on an option
    ///         whose value is missing, or when both the exponential gamma
    ///         and the polynomial degree end up non-zero.
    void Input(const char* aString)
    {
        // Tokenize on whitespace; formatted extraction never yields an
        // empty token, so no extra filtering is needed.
        std::vector<std::string> options;
        std::stringstream ss;
        ss<<std::string(aString);
        std::string data;
        while (ss>>data) options.push_back(data);

        const std::vector<std::string>::const_iterator end=options.end();
        for (std::vector<std::string>::const_iterator it=options.begin();it!=end;++it)
        {
            const std::string& option=*it;
            if (option=="-Kd") mPolyDegree=ReadOptionValue(it,end);
            else if (option=="-Kr") mPolyOffset=ReadOptionValue(it,end);
            else if (option=="-Ks") mPolyCoefficient=ReadOptionValue(it,end);
            else if (option=="-Kg") mExpGamma=ReadOptionValue(it,end);
            else if (option=="-KDot") mKernelType=DOT_PRODUCT_TYPE;
            else if (option=="-KHistInt") mKernelType=HISTOGRAM_INTERSECTION_TYPE;
            else if (option=="-KNoTotNorm") mTotalNormalization=false;
            else if (option=="-KNoProdNorm") mProductNormalization=false;
            else if (option=="-KNoAttrNorm") mAttributeNormalization=false;
            else throw(std::invalid_argument("Argument:"+option+" is not valid"));
        }
        // Tested separately rather than through their product, so tiny
        // values cannot underflow to a false "inactive".
        if (mExpGamma!=0 && mPolyDegree!=0)
            throw(std::invalid_argument("Error: exponential and polinomial cannot be set active at the same time"));
        ComputeInversePolinomialNormalizationCoefficient();
    }

    /// Recomputes 1/(mPolyCoefficient+mPolyOffset)^mPolyDegree, stores it in
    /// mInversePolinomialNormalizationCoefficient and returns it.
    float ComputeInversePolinomialNormalizationCoefficient()
    {
        mInversePolinomialNormalizationCoefficient=1/std::pow(mPolyCoefficient+mPolyOffset,mPolyDegree);
        return mInversePolinomialNormalizationCoefficient;
    }

    KernelType mKernelType;          ///< base kernel
    float mExpGamma;                 ///< value of -Kg (0 = not set)
    float mPolyDegree;               ///< value of -Kd (0 = not set)
    float mPolyOffset;               ///< value of -Kr
    float mPolyCoefficient;          ///< value of -Ks
    bool mTotalNormalization;        ///< cleared by -KNoTotNorm
    bool mProductNormalization;      ///< cleared by -KNoProdNorm
    bool mAttributeNormalization;    ///< cleared by -KNoAttrNorm
    float mInversePolinomialNormalizationCoefficient; ///< cached by ComputeInversePolinomialNormalizationCoefficient()

private:
    /// Advances aIt to the value that must follow the option it points to
    /// and returns that value converted with atof.
    /// @throws std::invalid_argument when the option is the last token.
    static float ReadOptionValue(std::vector<std::string>::const_iterator& aIt,
                                 const std::vector<std::string>::const_iterator& aEnd)
    {
        const std::string& name=*aIt;
        if (++aIt==aEnd)
            throw(std::invalid_argument("Argument:"+name+" requires a value"));
        return static_cast<float>(atof(aIt->c_str()));
    }
};

/// Parameter set shared by the data classes of this header (one instance
/// per translation unit, since it has internal linkage).
static UserKernelParametersClass user_kernel_param;
/// Converts a value to another type by writing it to a string stream and
/// reading it back.
///
/// @param t  value to convert; must be streamable with operator<<.
/// @return   the value extracted with operator>>; a value-initialized
///           OutType (0 for arithmetic types) when nothing is extracted.
template <class OutType, class InType>
OutType stream_cast(const InType & t)
{
    std::stringstream ss;
    ss << t;
    // Value-initialize: when the stream is empty the extraction leaves the
    // target untouched, which would otherwise return an indeterminate value.
    OutType result = OutType();
    ss >> result;
    return result;
}
00275
/// Splits a string at the first occurrence of a separator.
///
/// Declared inline because it is defined in a header; the former dynamic
/// exception specification was dropped (ill-formed since C++17).
///
/// @param aInput  string to split.
/// @param sep     separator character (':' by default).
/// @return        (text before the separator, text after it); later
///                occurrences of the separator stay in the second part.
/// @throws std::domain_error when aInput does not contain sep.
inline std::pair<std::string,std::string> Split(const std::string& aInput,char sep=':')
{
    const std::string::size_type pos=aInput.find(sep);
    if (pos==std::string::npos) throw std::domain_error("Non tokenizable element:"+aInput);
    return std::make_pair(aInput.substr(0,pos),aInput.substr(pos+1));
}
00285
/// Tells whether a string is a non-empty sequence of decimal digits.
///
/// Signs, decimal points and whitespace are rejected.  The empty string is
/// rejected as well, so that a token with a missing value is not accepted
/// as a number.  Declared inline because it is defined in a header.
///
/// @param aInput  string to test.
/// @return        true when every character is a digit and there is at
///                least one character.
inline bool IsNumber(const std::string& aInput)
{
    if (aInput.empty()) return false;
    for (std::string::size_type i=0;i<aInput.size();++i)
    {
        // isdigit requires a value representable as unsigned char (or EOF);
        // passing a negative plain char is undefined behaviour.
        if (!std::isdigit(static_cast<unsigned char>(aInput[i]))) return false;
    }
    return true;
}
00295
/// Strict-weak-ordering functor comparing two (index,value) pairs by their
/// index (first member) only; the value is ignored.
template<typename IndexType, typename ValueType>
struct LessPair{
    /// @return true when the index of A is smaller than the index of B.
    /// Const-qualified so the functor can be called through const objects.
    inline bool operator()(const std::pair<IndexType,ValueType>& A,const std::pair<IndexType,ValueType>& B) const
    {
        return A.first<B.first;
    }
};
00310
/// Unary predicate telling whether an (index,value) pair carries a given
/// index; the value is ignored.
template<typename IndexType, typename ValueType>
struct EqualPair{
    /// @param aItem  index to look for (copied into the predicate).
    /// Initializes the member directly, so IndexType does not need a
    /// default constructor.
    EqualPair(const IndexType& aItem):mItem(aItem){}

    /// @return true when the index of A equals the stored index.
    /// Const-qualified so the predicate can be called through const objects.
    inline bool operator()(const std::pair<IndexType,ValueType>& A) const
    {
        return A.first==mItem;
    }
    IndexType mItem; ///< index searched for
};
00327
00328
00329
00330
00331
// Forward declarations needed by the friend declaration inside the class.
template<typename IndexType, typename ValueType>
class SparseVectorClass;

template<typename IndexType, typename ValueType>
std::ostream& operator<< (std::ostream& ss, const SparseVectorClass<IndexType,ValueType>& data);


/// Sparse vector stored as a list of (index,value) pairs.
///
/// Elements are kept in insertion order; nothing in this class sorts them
/// or merges elements that share an index.
template<typename IndexType, typename ValueType>
class SparseVectorClass
{
    friend std::ostream& operator<< <IndexType,ValueType>(std::ostream& ss, const SparseVectorClass<IndexType,ValueType>& data);

public:
    SparseVectorClass()
    {
    }

    /// Appends the element (aIndex,aValue) at the end of the vector.
    void Insert(const IndexType& aIndex, const ValueType& aValue)
    {
        mData.push_back(std::pair<IndexType,ValueType>(aIndex,aValue));
    }

    /// Appends an already built (index,value) pair at the end of the vector.
    void Insert(const std::pair<IndexType,ValueType>& aItem)
    {
        mData.push_back(aItem);
    }

    /// Number of stored elements.
    int Size() const
    {
        return mData.size();
    }

public:
    std::list<std::pair<IndexType,ValueType> > mData; ///< stored (index,value) pairs
};
00363
00364
00365 template<typename IndexType, typename ValueType>
00366 ostream& operator<< (ostream& out, const SparseVectorClass<IndexType,ValueType>& data)
00367 {
00368 typename list<pair<IndexType,ValueType> >::const_iterator it=data.mData.begin();
00369 for (;it!=data.mData.end();++it)
00370 out<<" "<<it->second;
00371 return out;
00372 }
00373
00374
00375 template<>
00376 ostream& operator<< <IndexConcreteType,ValueConcreteType>(ostream& out, const SparseVectorClass<IndexConcreteType,ValueConcreteType>& data)
00377 {
00378 list<pair<IndexConcreteType,ValueConcreteType> >::const_iterator it=data.mData.begin();
00379 for (;it!=data.mData.end();++it)
00380 out<<" "<<it->first<<":"<<it->second;
00381 return out;
00382 }
00383
00384
00385
00386
00391 template<typename IndexType, typename ValueType>
00392 class DotSparseVectorClass:public SparseVectorClass<IndexType,ValueType>
00393 {
00394 public:
00395 DotSparseVectorClass(){}
00396
00398 void Sort()
00399 {
00400 this->mData.sort(LessPair<IndexType,ValueType>());
00401 }
00402
00405 RealType operator*(const DotSparseVectorClass<IndexType,ValueType>& aInstance)const
00406 {
00407 RealType result=0;
00408 typename list<pair<IndexType,ValueType> >::const_iterator it,jt;
00409 it=this->mData.begin();
00410 jt=aInstance.mData.begin();
00411 while(it!=this->mData.end() && jt!=aInstance.mData.end())
00412 {
00413 if (it->first==jt->first)
00414 {
00415 result+=it->second * jt->second;
00416 ++it;
00417 ++jt;
00418 }
00419 else if (it->first<jt->first) ++it;
00420 else
00421 ++jt;
00422 }
00423 return result;
00424 }
00425 };
00426
00427
00428
00429
00434 template<typename IndexType, typename ValueType>
00435 class HistogramIntersectionOptimizedSparseVectorClass:public SparseVectorClass<IndexType,ValueType>
00436 {
00437 friend ostream& operator<<(ostream& out, const HistogramIntersectionOptimizedSparseVectorClass& data)
00438 {
00439 out<<" dim:"<<data.mData.size();
00440 for (unsigned i=0;i<data.mData.size();++i)
00441 {
00442 out<<" bin:"<<data.mData[i].first;
00443 out<<" dim:"<<data.mData[i].second.size();
00444 for (unsigned j=0;j<data.mData[i].second.size();++j)
00445 out<<" "<<data.mData[i].second[j];
00446 }
00447 return out;
00448 }
00449
00450 public:
00457 void Insert(const IndexType& aIndex, const ValueType& aValue)
00458 {
00459
00460 EqualPair<IndexType,vector<ValueType> > equal(aIndex);
00461 typename vector<pair<IndexType,vector<ValueType> > >::iterator it=find_if(mData.begin(),mData.end(),equal);
00462
00463 if (it!=mData.end())
00464 {
00465
00466 it->second.push_back(aValue);
00467 }
00468
00469 else
00470 {
00471 vector<ValueType> tmp_vec;
00472 tmp_vec.push_back(aValue);
00473 mData.push_back(make_pair(aIndex,tmp_vec));
00474 }
00475 }
00476 void Insert(const pair<IndexType,ValueType>& aItem)
00477 {
00478 Insert(aItem.first,aItem.second);
00479 }
00481 void Sort()
00482 {
00483 sort(mData.begin(),mData.end(),LessPair<IndexType,vector<ValueType> >());
00484 for (typename vector<pair<IndexType,vector<ValueType> > >::iterator it=mData.begin();it!=mData.end();++it)
00485 {
00486 sort(it->second.begin(),it->second.end());
00487 }
00488 }
00491 RealType operator*(const HistogramIntersectionOptimizedSparseVectorClass<IndexType,ValueType>& aInstance)const
00492 {
00493 typename vector<pair<IndexType,vector<ValueType> > >::const_iterator it=mData.begin();
00494 typename vector<pair<IndexType,vector<ValueType> > >::const_iterator jt=aInstance.mData.begin();
00495
00496 ValueType result=0;
00497 while(it!=mData.end() && jt!=aInstance.mData.end())
00498 {
00499 if (it->first==jt->first)
00500 {
00501 result+=OptimizedHistogramIntersection(it->second,jt->second);
00502 it++;
00503 jt++;
00504 }
00505 else if (it->first < jt->first) it++;
00506 else
00507 jt++;
00508 }
00509 return result;
00510 }
00514 ValueType OptimizedHistogramIntersectionBaseVersion(const vector<ValueType>& aVecA, const vector<ValueType>& aVecB)const
00515 {
00516 ValueType result=0;
00517 for (unsigned i=0;i<aVecA.size();++i)
00518 {
00519 unsigned j;
00520 for (j=0;j<aVecB.size() && aVecB[j]<=aVecA[i];++j)
00521 result+=aVecB[j];
00522 result+=aVecA[i]*(aVecB.size()-j);
00523 }
00524 return result;
00525 }
00528 ValueType OptimizedHistogramIntersection(const vector<ValueType>& aVecA, const vector<ValueType>& aVecB)const
00529 {
00530 ValueType result=0;
00531 ValueType cumulative=0;
00532 unsigned i=0,j=0;
00533 while (i<aVecA.size() && j<aVecB.size())
00534 {
00535 if (aVecA[i]<aVecB[j])
00536 {
00537 result+=aVecA[i]*(aVecB.size()-j)+cumulative;
00538 i++;
00539 }
00540 else if (aVecA[i]>=aVecB[j])
00541 {
00542 cumulative+=aVecB[j];
00543 j++;
00544 }
00545 }
00546
00547 result+=(aVecA.size()-i)*cumulative;
00548 return result;
00549 }
00550 public:
00551 vector<pair<IndexType,vector<ValueType> > > mData;
00566 };
00567
00568
00569
00574 class AttributeClass
00575 {
00576 friend ostream& operator<<(ostream& out, const AttributeClass& data)
00577 {
00578 out<<"attribute:"<<data.mAttributeType<<" ";
00579 out<<"dim:"<<data.mValue.Size();
00580 out<<data.mValue;
00581 out<<data.mOptimizedValue;
00582 return out;
00583 }
00584 friend istream& operator>>(istream& in, AttributeClass& data) throw(exception)
00585 {
00586 string token;
00587 in>>token;pair<string,string> key_value_token=Split(token);
00588 if(key_value_token.first!="attribute") throw domain_error("Unexpected element:"+key_value_token.first);
00589
00590 data.mAttributeType=stream_cast<IndexConcreteType>(key_value_token.second);
00591
00592 in>>token;key_value_token=Split(token);
00593 if (key_value_token.first!="dim") throw domain_error("Unexpected element:"+key_value_token.first);
00594 if (!IsNumber(key_value_token.second)) throw domain_error("Expected a number:"+key_value_token.second);
00595 int dim=stream_cast<int>(key_value_token.second);
00596
00597 for (int i=0;i<dim;i++)
00598 {
00599 IndexConcreteType index;
00600 ValueConcreteType value;
00601 in>>token;pair<string,string> key_value_token=Split(token);
00602 if (!IsNumber(key_value_token.first)) throw domain_error("Expected a number:"+key_value_token.first);
00603 if (!IsNumber(key_value_token.second)) throw domain_error("Expected a number:"+key_value_token.second);
00604 index=stream_cast<IndexConcreteType>(key_value_token.first);
00605 value=stream_cast<ValueConcreteType>(key_value_token.second);
00606 data.mValue.Insert(index,value);
00607 data.mOptimizedValue.Insert(index,value);
00608 }
00609 return in;
00610 }
00611 public:
00612 AttributeClass():
00613 mAttributeType(0),
00614 mNorm(0),
00615 mInverseSquaredNorm(0){}
00620 void Merge(AttributeClass& aAttribute) throw(exception)
00621 {
00622 switch(user_kernel_param.mKernelType)
00623 {
00624 case DOT_PRODUCT_TYPE:
00625 MergeDotProduct(aAttribute);
00626 break;
00627 case HISTOGRAM_INTERSECTION_TYPE:
00628 MergeHistogramIntersection(aAttribute);
00629 break;
00630 default: throw(invalid_argument("ERROR:Unknown kernel"));
00631 }
00632 }
00635 RealType operator*(const AttributeClass& aInstance)const throw(exception)
00636 {
00637 if (user_kernel_param.mAttributeNormalization)
00638 {
00639
00640
00641
00642
00643 if (mNorm==0) ComputeNormalizationFactor();
00644 if (aInstance.mNorm==0) aInstance.ComputeNormalizationFactor();
00645 switch(user_kernel_param.mKernelType)
00646 {
00647 case DOT_PRODUCT_TYPE:
00648 return (mValue*aInstance.mValue)*mInverseSquaredNorm*aInstance.mInverseSquaredNorm;
00649 case HISTOGRAM_INTERSECTION_TYPE:
00650 return (mOptimizedValue*aInstance.mOptimizedValue)*mInverseSquaredNorm*aInstance.mInverseSquaredNorm;
00651 default: throw(invalid_argument("ERROR:Unknown kernel"));
00652 }
00653 }
00654 else
00655 {
00656 switch(user_kernel_param.mKernelType)
00657 {
00658 case DOT_PRODUCT_TYPE:
00659 return (mValue*aInstance.mValue);
00660 case HISTOGRAM_INTERSECTION_TYPE:
00661 return (mOptimizedValue*aInstance.mOptimizedValue);
00662 default: throw(invalid_argument("ERROR:Unknown kernel"));
00663 }
00664 }
00665 }
00666 protected:
00668 void MergeDotProduct(AttributeClass& aAttribute)
00669 {
00670
00671 list<pair<IndexConcreteType,ValueConcreteType> >::iterator it,jt;
00672 it=mValue.mData.begin();
00673 jt=aAttribute.mValue.mData.begin();
00674
00675 while(it!=mValue.mData.end() && jt!=aAttribute.mValue.mData.end())
00676 {
00677 if (it->first==jt->first)
00678 {
00679 it->second+=jt->second;
00680 ++it;
00681 ++jt;
00682 }
00683 else if (it->first < jt->first) ++it;
00684
00685 else
00686 {
00687 mValue.Insert(*jt);
00688 ++jt;
00689 }
00690 }
00691
00692 while (jt!=aAttribute.mValue.mData.end())
00693 {
00694 mValue.Insert(*jt);
00695 ++jt;
00696 }
00697
00698 mValue.Sort();
00699 }
00702 void MergeHistogramIntersection(AttributeClass& aAttribute)
00703 {
00704 for (list<pair<IndexConcreteType,ValueConcreteType> >::iterator it=aAttribute.mValue.mData.begin();it!=aAttribute.mValue.mData.end();++it)
00705 mOptimizedValue.Insert(it->first,it->second);
00706
00707
00708 mOptimizedValue.Sort();
00709 }
00712 void ComputeNormalizationFactor()const throw(exception)
00713 {
00714 switch(user_kernel_param.mKernelType)
00715 {
00716 case DOT_PRODUCT_TYPE:
00717 mNorm=(mValue*mValue);
00718 mInverseSquaredNorm=1/sqrt(mNorm);
00719 break;
00720 case HISTOGRAM_INTERSECTION_TYPE:
00721 mNorm=(mOptimizedValue*mOptimizedValue);
00722 mInverseSquaredNorm=1/sqrt(mNorm);
00723 break;
00724 default:
00725 throw(invalid_argument("ERROR:Unknown kernel"));
00726 }
00727 }
00728
00729 public:
00730 IndexConcreteType mAttributeType;
00731 mutable RealType mNorm;
00732 mutable RealType mInverseSquaredNorm;
00733 DotSparseVectorClass<IndexConcreteType,ValueConcreteType> mValue;
00734 HistogramIntersectionOptimizedSparseVectorClass<IndexConcreteType,ValueConcreteType> mOptimizedValue;
00735 };
00736
00737
00741 class PartClass
00742 {
00743 public:
00744 friend ostream& operator<<(ostream& out, const PartClass& data)
00745 {
00746 out<<"part:"<<data.mPartType<<" ";
00747 out<<"dim:"<<data.mAttr.Size();
00748 out<<data.mAttr;
00749 return out;
00750 }
00751 friend istream& operator>>(istream& in, PartClass& data) throw(exception)
00752 {
00753 string token;
00754 in>>token;pair<string,string> key_value_token=Split(token);
00755 if (key_value_token.first!="part") throw domain_error("Unexpected element:"+key_value_token.first);
00756 data.mPartType=stream_cast<IndexConcreteType>(key_value_token.second);
00757 in>>token;key_value_token=Split(token);
00758 if (key_value_token.first!="dim") throw domain_error("Unexpected element:"+key_value_token.first);
00759 if (!IsNumber(key_value_token.second)) throw domain_error("Expected a number:"+key_value_token.second);
00760 int dim=stream_cast<int>(key_value_token.second);
00761 for (int i=0;i<dim;i++)
00762 {
00763 AttributeClass attribute;
00764 in>>attribute;
00765 data.mAttr.Insert(attribute.mAttributeType,attribute);
00766 }
00767 data.mAttr.Sort();
00768 return in;
00769 }
00770 public:
00771 PartClass(){}
00774 void Merge(PartClass& aPart)
00775 {
00776
00777 list<pair<IndexConcreteType,AttributeClass> >::iterator it,jt;
00778 it=mAttr.mData.begin();
00779 jt=aPart.mAttr.mData.begin();
00780
00781 while(it!=mAttr.mData.end() && jt!=aPart.mAttr.mData.end())
00782 {
00783 if (it->first==jt->first)
00784 {
00785 it->second.Merge(jt->second);
00786 ++it;
00787 ++jt;
00788 }
00789 else if (it->first < jt->first) ++it;
00790
00791 else
00792 {
00793 mAttr.Insert(*jt);
00794 ++jt;
00795 }
00796 }
00797
00798 while (jt!=aPart.mAttr.mData.end())
00799 {
00800 mAttr.Insert(*jt);
00801 ++jt;
00802 }
00803
00804 mAttr.Sort();
00805 }
00808 RealType operator*(const PartClass& aInstance)const
00809 {
00810 return mAttr*aInstance.mAttr;
00811 }
00812 public:
00813 IndexConcreteType mPartType;
00814 DotSparseVectorClass<IndexConcreteType,AttributeClass> mAttr;
00815 };
00816
00817
00818
00822 class WDKDataClass
00823 {
00824 public:
00825 friend ostream& operator<<(ostream& out, const WDKDataClass& data)
00826 {
00827 out<<"dim:"<<data.mPart.Size();
00828 out<<data.mPart;
00829 return out;
00830 }
00831 friend istream& operator>>(istream& in, WDKDataClass& data) throw(exception)
00832 {
00833 string token;
00834 int dim;
00835 in>>token;pair<string,string> key_value_token=Split(token);
00836 if (key_value_token.first!="dim") throw domain_error("Unexpected element:"+key_value_token.first);
00837 if (!IsNumber(key_value_token.second)) throw domain_error("Expected a number:"+key_value_token.second);
00838 dim=stream_cast<int>(key_value_token.second);
00839 for (int i=0;i<dim;i++)
00840 {
00841 PartClass part;
00842 in>>part;
00843 data.mPart.Insert(part.mPartType,part);
00844 }
00845 data.Optimize();
00846 return in;
00847 }
00848 public:
00849 WDKDataClass():mNorm(0),mInverseSquaredNorm(0){}
00854 WDKDataClass(const char* aS):mNorm(1),mInverseSquaredNorm(1)
00855 {
00856 mSerialized=(string)(aS);
00857 stringstream ss(aS);
00858 ss>>(*this);
00859 ComputeNormalizationFactor();
00860 }
00863 string Serialize()
00864 {
00865 return mSerialized;
00866 }
00869 RealType operator*(const WDKDataClass& aInstance)const
00870 {
00871 if (mNorm==0 || aInstance.mNorm==0) return 0;
00872 if (user_kernel_param.mProductNormalization)
00873 return Product(aInstance)*mInverseSquaredNorm*aInstance.mInverseSquaredNorm;
00874 else
00875 return Product(aInstance);
00876 }
00877 protected:
00880 void Optimize()
00881 {
00882
00883 mPart.Sort();
00884
00885
00886
00887 for (list<pair<IndexConcreteType,PartClass> >::iterator it=mPart.mData.begin();it!=mPart.mData.end();)
00888 {
00889 list<pair<IndexConcreteType,PartClass> >::iterator jt=it;
00890 jt++;
00891 if (jt!=mPart.mData.end())
00892 {
00893
00894 if (it->first==jt->first)
00895 {
00896 it->second.Merge(jt->second);
00897 mPart.mData.erase(jt);
00898 }
00899
00900 else ++it;
00901 }
00902 else ++it;
00903 }
00904 }
00907 void ComputeNormalizationFactor()
00908 {
00909 mNorm=Product(*this);
00910 mInverseSquaredNorm=1/sqrt(mNorm);
00911 }
00912 protected:
00914 RealType Product(const WDKDataClass& aInstance)const
00915 {
00916 return mPart*aInstance.mPart;
00917 }
00918 public:
00919 RealType mNorm;
00920 RealType mInverseSquaredNorm;
00921 protected:
00922 DotSparseVectorClass<IndexConcreteType,PartClass> mPart;
00923 string mSerialized;
00924 };
00925
00929 class UserKernelClass
00930 {
00931 public:
00933 RealType K(const WDKDataClass& aX,const WDKDataClass& aZ, const UserKernelParametersClass& aParam)
00934 {
00935 try
00936 {
00937 if (aParam.mTotalNormalization)
00938 {
00939 if (aParam.mPolyDegree!=0)
00940 {
00941 if (aParam.mProductNormalization) return k(aX,aZ,aParam)*aParam.mInversePolinomialNormalizationCoefficient;
00942 else return k(aX,aZ,aParam)/sqrt(k(aX,aX,aParam)*k(aZ,aZ,aParam));
00943 }
00944 else
00945 return k(aX,aZ,aParam);
00946 }
00947 else return k(aX,aZ,aParam);
00948 }
00949 catch(exception e)
00950 {
00951 cerr<<e.what()<<endl;
00952 exit(1);
00953 }
00954 }
00956 RealType k(const WDKDataClass& aX,const WDKDataClass& aZ, const UserKernelParametersClass& aParam)
00957 {
00958 if (aParam.mPolyDegree!=0)
00959 {
00960 return pow(aParam.mPolyCoefficient*(aX*aZ)+aParam.mPolyOffset,aParam.mPolyDegree);
00961 }
00962 else if (aParam.mExpGamma)
00963 {
00964 if (aParam.mProductNormalization)
00965 return exp(-aParam.mExpGamma*(2 -2*(aX*aZ)));
00966 else
00967 return exp(-aParam.mExpGamma*(aX.mNorm -2*(aX*aZ) + aZ.mNorm));
00968 }
00969 else return (aX*aZ);
00970 }
00971 };
00972
00973 #endif