Main Page | Class Hierarchy | Class List | File List | Class Members

wdk.h

00001 #ifndef WDK_H
00002 #define WDK_H
00003 
00004 
#include <algorithm>
#include <iostream>
#include <list>
#include <map>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdlib>
00184 
00185 using namespace std;
00186 
00187 
// Floating-point type of every kernel value, norm and normalization factor.
typedef float RealType;
// Integer type of sparse-vector indices: part types, attribute types and bin indices.
typedef long IndexConcreteType;
// Integer type of the values stored in the attribute sparse vectors.
typedef signed int ValueConcreteType;
00191 
00192 
/// Kernel used to compare attribute histograms.
enum KernelType {DOT_PRODUCT_TYPE=0,HISTOGRAM_INTERSECTION_TYPE};

/// User-settable kernel parameters, filled from an option string by Input().
///
/// Defaults: histogram intersection kernel, no polynomial/exponential wrapping
/// (degree 0, gamma 0), offset 1, coefficient 1, all normalizations enabled.
class UserKernelParametersClass
{
 public:
  UserKernelParametersClass():
    mKernelType(HISTOGRAM_INTERSECTION_TYPE),
    mExpGamma(0),
    mPolyDegree(0),
    mPolyOffset(1),
    mPolyCoefficient(1),
    mTotalNormalization(true),
    mProductNormalization(true),
    mAttributeNormalization(true),
    // 1/pow(mPolyCoefficient+mPolyOffset,mPolyDegree) for the defaults above (degree 0)
    mInversePolinomialNormalizationCoefficient(1)
    {}

  /// Parses whitespace-separated options from aString.
  ///
  /// Value options (each followed by a number, converted with atof):
  ///   -Kd <degree>  -Kr <offset>  -Ks <coefficient>  -Kg <gamma>
  /// Flags: -KDot, -KHistInt, -KNoTotNorm, -KNoProdNorm, -KNoAttrNorm.
  /// A null aString is treated as an empty option string.
  ///
  /// @throws invalid_argument on an unknown option, on a value option given
  ///         without a value, or when both -Kd and -Kg are non-zero.
  void Input(const char* aString)
    {
      vector<string> options;
      stringstream ss((aString!=0) ? string(aString) : string());
      string data;
      // extraction-driven loop: stops cleanly at end of input and never yields empty tokens
      while (ss>>data) options.push_back(data);

      for (vector<string>::const_iterator it=options.begin();it!=options.end();++it)
        {
          const string& opt=*it;
          if (opt=="-Kd") mPolyDegree=NextValue(it,options.end());
          else if (opt=="-Kr") mPolyOffset=NextValue(it,options.end());
          else if (opt=="-Ks") mPolyCoefficient=NextValue(it,options.end());
          else if (opt=="-Kg") mExpGamma=NextValue(it,options.end());
          else if (opt=="-KDot") mKernelType=DOT_PRODUCT_TYPE;
          else if (opt=="-KHistInt") mKernelType=HISTOGRAM_INTERSECTION_TYPE;
          else if (opt=="-KNoTotNorm") mTotalNormalization=false;
          else if (opt=="-KNoProdNorm") mProductNormalization=false;
          else if (opt=="-KNoAttrNorm") mAttributeNormalization=false;
          else throw(invalid_argument("Argument:"+opt+" is not valid"));
        }
      // Mutual exclusion between polynomial and exponential kernel. Compare each factor
      // to zero explicitly: a float product of two tiny non-zero values can underflow to 0.
      if (mExpGamma!=0 && mPolyDegree!=0)
        throw(invalid_argument("Error: exponential and polinomial cannot be set active at the same time"));
      ComputeInversePolinomialNormalizationCoefficient();
    }

  /// Recomputes and returns 1/(mPolyCoefficient+mPolyOffset)^mPolyDegree, the factor that
  /// rescales the polynomial kernel to 1 for unit-normalized inputs.
  float ComputeInversePolinomialNormalizationCoefficient()
    {
      mInversePolinomialNormalizationCoefficient=1/pow(mPolyCoefficient+mPolyOffset,mPolyDegree);
      return mInversePolinomialNormalizationCoefficient;
    }

  KernelType mKernelType;
  float mExpGamma;
  float mPolyDegree;
  float mPolyOffset;
  float mPolyCoefficient;
  bool mTotalNormalization;
  bool mProductNormalization;
  bool mAttributeNormalization;
  float mInversePolinomialNormalizationCoefficient;

 private:
  /// Advances aIt from a value option to its value and returns that value (via atof).
  /// @throws invalid_argument if the option is the last token, i.e. no value follows.
  static float NextValue(vector<string>::const_iterator& aIt, vector<string>::const_iterator aEnd)
    {
      const string& option=*aIt;
      ++aIt;
      if (aIt==aEnd) throw(invalid_argument("Argument:"+option+" requires a value"));
      return static_cast<float>(atof(aIt->c_str()));
    }
};
00259 
/// Converts t to OutType by writing it into a stringstream with operator<< and reading it
/// back with operator>>. Only the first whitespace-delimited token is read.
///
/// If nothing can be extracted (e.g. t prints as an empty string, so the stream sentry fails
/// before any conversion), a value-initialized OutType is returned (0 for arithmetic types,
/// empty for strings) instead of an uninitialized value.
template <class OutType, class InType>
OutType stream_cast(const InType & t)
{
  stringstream ss;
  ss << t;
  OutType result = OutType();
  ss >> result;
  return result;
}
00275 
/// Splits aInput at the first occurrence of sep into (text before, text after); the separator
/// is dropped and the second part may itself contain further separators.
/// @throws domain_error if sep does not occur in aInput.
///
/// Declared inline because it is defined in a header: otherwise every translation unit that
/// includes this file emits its own external definition (multiple-definition link error).
inline pair<string,string> Split(const string& aInput,char sep=':')
{
  const string::size_type pos=aInput.find(sep);
  if (pos==string::npos) throw domain_error("Non tokenizable element:"+aInput);
  return make_pair(aInput.substr(0,pos),aInput.substr(pos+1));
}
00285 
/// True if aInput is a non-empty string made only of decimal digits (no sign, no decimal point).
///
/// Each char is converted to unsigned char before calling isdigit: passing a negative
/// char value (e.g. a non-ASCII byte) to isdigit is undefined behaviour.
/// Declared inline because it is defined in a header (one-definition rule).
inline bool IsNumber(const string& aInput)
{
  if (aInput.empty()) return false;
  for (string::size_type i=0;i<aInput.size();++i)
    if (!isdigit(static_cast<unsigned char>(aInput[i]))) return false;
  return true;
}
00295 
/// Strict-weak-ordering functor comparing two pairs by their first member (the index) only;
/// the second member does not take part in the comparison.
/// operator() is const so the functor can also be used through a const object/reference.
template<typename IndexType, typename ValueType>
  struct LessPair{
    inline bool operator()(const pair<IndexType,ValueType>& A,const pair<IndexType,ValueType>& B) const
    {
      return A.first<B.first;
    }
  };
00310 
/// Unary predicate: true for a pair whose first member (the index) equals the stored mItem.
/// Intended for find_if over containers of (index, value) pairs.
template<typename IndexType, typename ValueType>
  struct EqualPair{
    EqualPair(const IndexType& aItem):mItem(aItem){}
    inline bool operator()(const pair<IndexType,ValueType>& A) const
    {
      return A.first==mItem;
    }
    IndexType mItem;
  };
00327 
00328 //-----------------------------------------------------------------------------------------------
00329 
// A friend declaration naming a specialization of a function template (operator<< <...>)
// needs that template to be declared beforehand, so the class template and the output
// operator are forward-declared here and defined afterwards.
template<class IndexType, class ValueType> class SparseVectorClass;
template<class IndexType, class ValueType>
std::ostream& operator<< (std::ostream& ss, const SparseVectorClass<IndexType,ValueType>& data);


/// Sparse vector kept as a list of (index, value) pairs in insertion order.
/// Insert() neither sorts nor de-duplicates; derived classes provide ordering and products.
template<class IndexType, class ValueType>
class SparseVectorClass
{
  friend std::ostream& operator<< <IndexType,ValueType>(std::ostream& ss, const SparseVectorClass<IndexType,ValueType>& data);
 public:
  SparseVectorClass(){}

  /// Appends the entry (aIndex, aValue) at the end of the list.
  void Insert(const IndexType& aIndex, const ValueType& aValue)
    {
      Insert(std::pair<IndexType,ValueType>(aIndex,aValue));
    }

  /// Appends an already-built (index, value) pair at the end of the list.
  void Insert(const std::pair<IndexType,ValueType>& aItem)
    {
      mData.push_back(aItem);
    }

  /// Number of stored entries, duplicates included.
  int Size()const
    {
      return static_cast<int>(mData.size());
    }
 public:
  std::list<std::pair<IndexType,ValueType> > mData;
};

// Generic output: only the values are written (each preceded by a space), since callers
// describe the vector with their own prefix; the indices are omitted.
template<class IndexType, class ValueType>
std::ostream& operator<< (std::ostream& out, const SparseVectorClass<IndexType,ValueType>& data)
{
  typedef typename std::list<std::pair<IndexType,ValueType> >::const_iterator Cursor;
  const Cursor stop=data.mData.end();
  for (Cursor pos=data.mData.begin(); pos!=stop; ++pos)
    {
      out<<" "<<pos->second;
    }
  return out;
}
00373 
00374 //Note: in the special case of an <int:int> sparse vector the output includes both the index and the value
00375 template<>
00376 ostream& operator<< <IndexConcreteType,ValueConcreteType>(ostream& out, const SparseVectorClass<IndexConcreteType,ValueConcreteType>& data)
00377     {
00378       list<pair<IndexConcreteType,ValueConcreteType> >::const_iterator it=data.mData.begin();
00379       for (;it!=data.mData.end();++it)
00380         out<<" "<<it->first<<":"<<it->second;
00381       return out;
00382     }
00383 
00384 
00385 //-----------------------------------------------------------------------------------------------
00386 
00391 template<typename IndexType, typename ValueType>
00392 class DotSparseVectorClass:public SparseVectorClass<IndexType,ValueType>
00393 {
00394  public:
00395   DotSparseVectorClass(){}
00396 
00398   void Sort()
00399     {
00400       this->mData.sort(LessPair<IndexType,ValueType>());
00401     }
00402 
00405   RealType operator*(const DotSparseVectorClass<IndexType,ValueType>& aInstance)const
00406     {
00407       RealType result=0;
00408       typename list<pair<IndexType,ValueType> >::const_iterator it,jt;
00409       it=this->mData.begin();
00410       jt=aInstance.mData.begin();
00411       while(it!=this->mData.end() && jt!=aInstance.mData.end())
00412         {
00413           if (it->first==jt->first)
00414             {
00415               result+=it->second * jt->second;
00416               ++it;
00417               ++jt;
00418             }
00419           else if (it->first<jt->first) ++it;
00420           else //(it->first==jt->first)
00421             ++jt;
00422         }
00423       return result;
00424     }
00425 };
00426 
00427 
00428 //-----------------------------------------------------------------------------------------------
00429 
00434 template<typename IndexType, typename ValueType>
00435 class HistogramIntersectionOptimizedSparseVectorClass:public SparseVectorClass<IndexType,ValueType>
00436 {
00437   friend ostream& operator<<(ostream& out, const HistogramIntersectionOptimizedSparseVectorClass& data)
00438     {
00439       out<<" dim:"<<data.mData.size();
00440       for (unsigned i=0;i<data.mData.size();++i)
00441         {
00442           out<<" bin:"<<data.mData[i].first;
00443           out<<" dim:"<<data.mData[i].second.size();
00444           for (unsigned j=0;j<data.mData[i].second.size();++j)
00445             out<<" "<<data.mData[i].second[j];
00446         }
00447       return out;
00448     }
00449   
00450  public:
00457   void Insert(const IndexType& aIndex, const ValueType& aValue)
00458     {
00459       //find index
00460       EqualPair<IndexType,vector<ValueType> > equal(aIndex);
00461       typename vector<pair<IndexType,vector<ValueType> > >::iterator it=find_if(mData.begin(),mData.end(),equal);
00462       //if it exists
00463       if (it!=mData.end())
00464         {
00465           //add value to corresponding vector
00466           it->second.push_back(aValue);
00467         }
00468       //else add new vector containing value
00469       else
00470         {
00471           vector<ValueType> tmp_vec;
00472           tmp_vec.push_back(aValue);
00473           mData.push_back(make_pair(aIndex,tmp_vec));
00474         }
00475     }
00476   void Insert(const pair<IndexType,ValueType>& aItem)
00477     {
00478       Insert(aItem.first,aItem.second);
00479     }
00481   void Sort()
00482     {
00483       sort(mData.begin(),mData.end(),LessPair<IndexType,vector<ValueType> >());
00484       for (typename vector<pair<IndexType,vector<ValueType> > >::iterator it=mData.begin();it!=mData.end();++it)
00485         {
00486           sort(it->second.begin(),it->second.end());
00487         }
00488     }
00491   RealType operator*(const HistogramIntersectionOptimizedSparseVectorClass<IndexType,ValueType>& aInstance)const
00492     {
00493       typename vector<pair<IndexType,vector<ValueType> > >::const_iterator it=mData.begin();
00494       typename vector<pair<IndexType,vector<ValueType> > >::const_iterator jt=aInstance.mData.begin();
00495 
00496       ValueType result=0;
00497       while(it!=mData.end() && jt!=aInstance.mData.end())
00498         {
00499           if (it->first==jt->first)
00500             {
00501               result+=OptimizedHistogramIntersection(it->second,jt->second);
00502               it++;
00503               jt++;
00504             }
00505           else if (it->first < jt->first) it++;
00506           else //it->first > jt->first
00507             jt++;
00508         }
00509       return result;
00510     }
00514   ValueType OptimizedHistogramIntersectionBaseVersion(const vector<ValueType>& aVecA, const vector<ValueType>& aVecB)const
00515     {
00516       ValueType result=0;
00517       for (unsigned i=0;i<aVecA.size();++i)
00518         {
00519           unsigned j;
00520           for (j=0;j<aVecB.size() && aVecB[j]<=aVecA[i];++j)
00521             result+=aVecB[j];
00522           result+=aVecA[i]*(aVecB.size()-j);
00523         }
00524       return result;
00525     }
00528   ValueType OptimizedHistogramIntersection(const vector<ValueType>& aVecA, const vector<ValueType>& aVecB)const
00529     {
00530       ValueType result=0;
00531       ValueType cumulative=0;
00532       unsigned i=0,j=0;
00533       while (i<aVecA.size() && j<aVecB.size())
00534         {
00535           if (aVecA[i]<aVecB[j])
00536             {
00537               result+=aVecA[i]*(aVecB.size()-j)+cumulative;
00538               i++;
00539             }
00540           else if (aVecA[i]>=aVecB[j])
00541             {
00542               cumulative+=aVecB[j];
00543               j++;
00544             }
00545         }
00546       //if there are still elements in aVecA (i.e. i<aVewcA.size()) but not in aVecB then it means that they are all greater and therefore:
00547       result+=(aVecA.size()-i)*cumulative;
00548       return result;
00549     }
00550  public:
00551   vector<pair<IndexType,vector<ValueType> > > mData;
00566 };
00567 
00568 //-----------------------------------------------------------------------------------------------
00569 
00574 class AttributeClass
00575 {
00576   friend ostream& operator<<(ostream& out, const AttributeClass& data)
00577     {
00578       out<<"attribute:"<<data.mAttributeType<<" ";
00579       out<<"dim:"<<data.mValue.Size();
00580       out<<data.mValue;
00581       out<<data.mOptimizedValue;
00582       return out;
00583     }
00584   friend istream& operator>>(istream& in, AttributeClass& data) throw(exception)
00585     {
00586       string token;
00587       in>>token;pair<string,string> key_value_token=Split(token);
00588       if(key_value_token.first!="attribute") throw domain_error("Unexpected element:"+key_value_token.first);
00589         
00590       data.mAttributeType=stream_cast<IndexConcreteType>(key_value_token.second);
00591 
00592       in>>token;key_value_token=Split(token);
00593       if (key_value_token.first!="dim") throw domain_error("Unexpected element:"+key_value_token.first);
00594       if (!IsNumber(key_value_token.second)) throw domain_error("Expected a number:"+key_value_token.second);
00595       int dim=stream_cast<int>(key_value_token.second);
00596 
00597       for (int i=0;i<dim;i++)
00598         {
00599           IndexConcreteType index;
00600           ValueConcreteType value;
00601           in>>token;pair<string,string> key_value_token=Split(token);
00602           if (!IsNumber(key_value_token.first)) throw domain_error("Expected a number:"+key_value_token.first);
00603           if (!IsNumber(key_value_token.second)) throw domain_error("Expected a number:"+key_value_token.second);
00604           index=stream_cast<IndexConcreteType>(key_value_token.first);
00605           value=stream_cast<ValueConcreteType>(key_value_token.second);
00606           data.mValue.Insert(index,value);
00607           data.mOptimizedValue.Insert(index,value);
00608         }
00609       return in;
00610     }
00611  public:
00612   AttributeClass():
00613     mAttributeType(0),
00614     mNorm(0),
00615     mInverseSquaredNorm(0){}
00620   void Merge(AttributeClass& aAttribute) throw(exception)
00621     {
00622       switch(user_kernel_param.mKernelType)
00623         {
00624         case DOT_PRODUCT_TYPE:
00625           MergeDotProduct(aAttribute);
00626           break;
00627         case HISTOGRAM_INTERSECTION_TYPE:
00628           MergeHistogramIntersection(aAttribute);
00629           break;
00630         default: throw(invalid_argument("ERROR:Unknown kernel"));
00631         }
00632     }
00635   RealType operator*(const AttributeClass& aInstance)const throw(exception)
00636     {
00637       if (user_kernel_param.mAttributeNormalization)
00638         {
00639           //Note: lazy normalization: this is needed since the
00640           //optimization procedure alters the attribute data types, so
00641           //that no normalization can be computed just after parsing
00642           //the data in input
00643           if (mNorm==0) ComputeNormalizationFactor();
00644           if (aInstance.mNorm==0) aInstance.ComputeNormalizationFactor();       
00645           switch(user_kernel_param.mKernelType)
00646             {
00647             case DOT_PRODUCT_TYPE:
00648               return (mValue*aInstance.mValue)*mInverseSquaredNorm*aInstance.mInverseSquaredNorm;
00649             case HISTOGRAM_INTERSECTION_TYPE:
00650               return (mOptimizedValue*aInstance.mOptimizedValue)*mInverseSquaredNorm*aInstance.mInverseSquaredNorm;
00651             default: throw(invalid_argument("ERROR:Unknown kernel"));
00652             }
00653         }
00654       else
00655         {
00656           switch(user_kernel_param.mKernelType)
00657             {
00658             case DOT_PRODUCT_TYPE:
00659               return (mValue*aInstance.mValue);
00660             case HISTOGRAM_INTERSECTION_TYPE:
00661               return (mOptimizedValue*aInstance.mOptimizedValue);
00662             default: throw(invalid_argument("ERROR:Unknown kernel"));
00663             }
00664         }
00665     }
00666  protected:
00668   void MergeDotProduct(AttributeClass& aAttribute)
00669     {
00670       //invoke add on matching histogram objects
00671       list<pair<IndexConcreteType,ValueConcreteType> >::iterator it,jt;
00672       it=mValue.mData.begin();
00673       jt=aAttribute.mValue.mData.begin();
00674 
00675       while(it!=mValue.mData.end() && jt!=aAttribute.mValue.mData.end())
00676         {
00677           if (it->first==jt->first)
00678             {
00679               it->second+=jt->second;
00680               ++it;
00681               ++jt;
00682             }
00683           else if (it->first < jt->first) ++it;
00684           //case when there is an element in aAttribute that is not present in current object
00685           else 
00686             {
00687               mValue.Insert(*jt);
00688               ++jt;
00689             }
00690         }
00691       //extremal cases: copy remaining aAttribute.mValue elements
00692       while (jt!=aAttribute.mValue.mData.end())
00693         {
00694           mValue.Insert(*jt);
00695           ++jt;
00696         }
00697       //sort for later use
00698       mValue.Sort();
00699     }
00702   void MergeHistogramIntersection(AttributeClass& aAttribute)
00703     {
00704       for (list<pair<IndexConcreteType,ValueConcreteType> >::iterator it=aAttribute.mValue.mData.begin();it!=aAttribute.mValue.mData.end();++it)
00705         mOptimizedValue.Insert(it->first,it->second);
00706      
00707       //sort for later use
00708       mOptimizedValue.Sort();
00709     }
00712   void ComputeNormalizationFactor()const throw(exception)
00713     {
00714       switch(user_kernel_param.mKernelType)
00715         {
00716         case DOT_PRODUCT_TYPE:
00717           mNorm=(mValue*mValue);
00718           mInverseSquaredNorm=1/sqrt(mNorm);
00719           break;
00720         case HISTOGRAM_INTERSECTION_TYPE:
00721           mNorm=(mOptimizedValue*mOptimizedValue);
00722           mInverseSquaredNorm=1/sqrt(mNorm);
00723           break;
00724         default:
00725           throw(invalid_argument("ERROR:Unknown kernel"));
00726         }
00727     }
00728 
00729  public:
00730   IndexConcreteType mAttributeType;
00731   mutable RealType mNorm;
00732   mutable RealType mInverseSquaredNorm;
00733   DotSparseVectorClass<IndexConcreteType,ValueConcreteType> mValue; 
00734   HistogramIntersectionOptimizedSparseVectorClass<IndexConcreteType,ValueConcreteType> mOptimizedValue;
00735 };
00736 
00737 //-----------------------------------------------------------------------------------------------
00741 class PartClass
00742 {
00743  public:
00744   friend ostream& operator<<(ostream& out, const PartClass& data)
00745     {
00746       out<<"part:"<<data.mPartType<<" ";
00747       out<<"dim:"<<data.mAttr.Size();
00748       out<<data.mAttr;
00749       return out;
00750     }
00751   friend istream& operator>>(istream& in, PartClass& data) throw(exception)
00752     {
00753       string token;
00754       in>>token;pair<string,string> key_value_token=Split(token);
00755       if (key_value_token.first!="part") throw domain_error("Unexpected element:"+key_value_token.first);
00756       data.mPartType=stream_cast<IndexConcreteType>(key_value_token.second);
00757       in>>token;key_value_token=Split(token);
00758       if (key_value_token.first!="dim") throw domain_error("Unexpected element:"+key_value_token.first);
00759       if (!IsNumber(key_value_token.second)) throw domain_error("Expected a number:"+key_value_token.second);
00760       int dim=stream_cast<int>(key_value_token.second);
00761       for (int i=0;i<dim;i++)
00762         {
00763           AttributeClass attribute;
00764           in>>attribute;
00765           data.mAttr.Insert(attribute.mAttributeType,attribute);
00766         }
00767       data.mAttr.Sort();
00768       return in;
00769     }
00770  public:
00771   PartClass(){}
00774   void Merge(PartClass& aPart)
00775     {
00776       //invoke add on matching attribute objects
00777       list<pair<IndexConcreteType,AttributeClass> >::iterator it,jt;
00778       it=mAttr.mData.begin();
00779       jt=aPart.mAttr.mData.begin();
00780 
00781       while(it!=mAttr.mData.end() && jt!=aPart.mAttr.mData.end())
00782         {
00783           if (it->first==jt->first)
00784             {
00785               it->second.Merge(jt->second);
00786               ++it;
00787               ++jt;
00788             }
00789           else if (it->first < jt->first) ++it;
00790           //case when there is an element in aPart that is not present in current object
00791           else 
00792             {
00793               mAttr.Insert(*jt);
00794               ++jt;
00795             }
00796         }
00797       //extremal cases: copy remaining aPart.mAttr elements
00798       while (jt!=aPart.mAttr.mData.end())
00799         {
00800           mAttr.Insert(*jt);
00801           ++jt;
00802         }
00803       //sort for later use
00804       mAttr.Sort();
00805     }
00808   RealType operator*(const PartClass& aInstance)const
00809     {
00810       return mAttr*aInstance.mAttr;
00811     }
00812  public:
00813   IndexConcreteType mPartType;
00814   DotSparseVectorClass<IndexConcreteType,AttributeClass> mAttr; 
00815 };
00816 
00817 //-----------------------------------------------------------------------------------------------
00818 
00822 class WDKDataClass
00823 {
00824  public: 
00825   friend ostream& operator<<(ostream& out, const WDKDataClass& data)
00826     {
00827       out<<"dim:"<<data.mPart.Size();
00828       out<<data.mPart;
00829       return out;
00830     }
00831   friend istream& operator>>(istream& in, WDKDataClass& data) throw(exception)
00832     {
00833       string token;
00834       int dim;
00835       in>>token;pair<string,string> key_value_token=Split(token);
00836       if (key_value_token.first!="dim") throw domain_error("Unexpected element:"+key_value_token.first);
00837       if (!IsNumber(key_value_token.second)) throw domain_error("Expected a number:"+key_value_token.second);
00838       dim=stream_cast<int>(key_value_token.second);
00839       for (int i=0;i<dim;i++)
00840         {
00841           PartClass part;
00842           in>>part;
00843           data.mPart.Insert(part.mPartType,part);
00844         }
00845       data.Optimize();
00846       return in;
00847     }
00848  public: 
00849   WDKDataClass():mNorm(0),mInverseSquaredNorm(0){}
00854   WDKDataClass(const char* aS):mNorm(1),mInverseSquaredNorm(1)
00855     {
00856       mSerialized=(string)(aS);
00857       stringstream ss(aS);
00858       ss>>(*this);
00859       ComputeNormalizationFactor();
00860     }
00863   string Serialize()
00864     {
00865       return mSerialized;
00866     }
00869   RealType operator*(const WDKDataClass& aInstance)const
00870     {
00871       if (mNorm==0 || aInstance.mNorm==0) return 0;//Case of at least one null element
00872       if (user_kernel_param.mProductNormalization)
00873         return Product(aInstance)*mInverseSquaredNorm*aInstance.mInverseSquaredNorm;
00874       else
00875         return Product(aInstance);
00876     }
00877  protected:
00880   void Optimize()
00881     {
00882       //sort by part type (selector)
00883       mPart.Sort();
00884       //optimize attributes: 
00885       //in dot product case compress attributes 
00886       //in histogram intersection kernel compute attribute bin vectors 
00887       for (list<pair<IndexConcreteType,PartClass> >::iterator it=mPart.mData.begin();it!=mPart.mData.end();)
00888         {
00889           list<pair<IndexConcreteType,PartClass> >::iterator jt=it;
00890           jt++;
00891           if (jt!=mPart.mData.end())
00892             {
00893               //if part types (selectors) are equal then optimize histograms
00894               if (it->first==jt->first)
00895                 {
00896                   it->second.Merge(jt->second);
00897                   mPart.mData.erase(jt);
00898                 }
00899               //else advance index
00900               else ++it;
00901             }
00902           else ++it;//we are one element before the end of data structure => just advance to next element and terminate
00903         }
00904     }
00907   void ComputeNormalizationFactor()
00908     {
00909       mNorm=Product(*this);
00910       mInverseSquaredNorm=1/sqrt(mNorm);
00911     }
00912  protected:
00914   RealType Product(const WDKDataClass& aInstance)const
00915     {
00916       return mPart*aInstance.mPart;
00917     }
00918  public:
00919   RealType mNorm;
00920   RealType mInverseSquaredNorm;
00921  protected:
00922   DotSparseVectorClass<IndexConcreteType,PartClass> mPart; 
00923   string mSerialized;
00924 };
00925 
00929 class UserKernelClass
00930 {
00931  public:
00933   RealType K(const WDKDataClass& aX,const WDKDataClass& aZ, const UserKernelParametersClass& aParam)
00934     {
00935       try
00936         {
00937           if (aParam.mTotalNormalization)
00938             {
00939               if (aParam.mPolyDegree!=0) 
00940                 {
00941                   if (aParam.mProductNormalization) return k(aX,aZ,aParam)*aParam.mInversePolinomialNormalizationCoefficient;
00942                   else return k(aX,aZ,aParam)/sqrt(k(aX,aX,aParam)*k(aZ,aZ,aParam));
00943                 }
00944               else //if (aParam.mExpGamma) or linear case 
00945                 return k(aX,aZ,aParam); //Note:normalization is implicit in exponential case
00946             }
00947           else return k(aX,aZ,aParam);
00948         }
00949       catch(exception e)
00950         {
00951           cerr<<e.what()<<endl;
00952           exit(1);
00953         }
00954     }
00956   RealType k(const WDKDataClass& aX,const WDKDataClass& aZ, const UserKernelParametersClass& aParam)
00957     {
00958       if (aParam.mPolyDegree!=0)
00959         {
00960           return pow(aParam.mPolyCoefficient*(aX*aZ)+aParam.mPolyOffset,aParam.mPolyDegree);
00961         }
00962       else if (aParam.mExpGamma)
00963         {
00964           if (aParam.mProductNormalization) //if data product is normalized then squared norm of data is 1
00965             return exp(-aParam.mExpGamma*(2 -2*(aX*aZ)));
00966           else
00967             return exp(-aParam.mExpGamma*(aX.mNorm -2*(aX*aZ) + aZ.mNorm));
00968         }
00969       else return (aX*aZ);
00970     }
00971 };
00972 
00973 #endif

Generated on Thu Aug 4 18:04:02 2005 for WeightedDecompositionalKernel by  doxygen 1.4.4