Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@ xcode*
.DS_Store
.idea
cmake-build-*
data/*
*.csv
*.tar
*.zip
*.tar.gz
*.xml
*.jpeg
*.jpg
*.png
*.txt
.travis/configs.hpp
Testing/*
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ find_package(Boost 1.49
COMPONENTS
filesystem
system
regex
program_options
serialization
unit_test_framework
Expand Down
19 changes: 19 additions & 0 deletions augmentation/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
project(augmentation)

option(DEBUG "DEBUG" OFF)

# Start the list with this directory itself; the individual source files are
# appended to it below.
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../")

# Every header that belongs to this module. The implementation header is
# listed explicitly so that it is tracked together with the declaration header.
set(SOURCES
  augmentation.hpp
  augmentation_impl.hpp
)

# Prefix each source with the directory it lives in.
foreach(file ${SOURCES})
  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()

# Append sources (with directory name) to list of all models sources (used at
# the parent scope).
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE)
187 changes: 187 additions & 0 deletions augmentation/augmentation.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/**
* @file augmentation.hpp
* @author Kartik Dutt
*
* Definition of Augmentation class for augmenting data.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#include <mlpack/methods/ann/layer/bilinear_interpolation.hpp>
#include <mlpack/core/util/to_lower.hpp>
#include <boost/regex.hpp>

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

#ifndef MODELS_AUGMENTATION_HPP
#define MODELS_AUGMENTATION_HPP

/**
 * Augmentation class used to perform augmentations by transforming the data.
 * For the list of supported augmentation, take a look at our wiki page.
 *
 * @code
 * Augmentation augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2);
 * // Every column of trainFeatures holds one 28 x 28 single-channel data point.
 * augmentation.Transform(trainFeatures, 28, 28);
 * @endcode
 */
class Augmentation
{
 public:
  //! Create the augmentation class object.
  Augmentation() :
      augmentations(std::vector<std::string>()),
      augmentationProbability(0.2)
  {
    // Nothing to do here.
  }

  /**
   * Constructor for augmentation class.
   *
   * The given names are stored in lower case, and every entry containing
   * "resize" is moved to the front of the list (the relative order of all
   * other entries is preserved).
   *
   * @param augmentations List of strings containing one of the supported
   *                      augmentations.
   * @param augmentationProbability Probability of applying augmentation on
   *                                the dataset.
   *                                NOTE : This doesn't apply to augmentations
   *                                such as resize.
   */
  Augmentation(const std::vector<std::string>& augmentations,
               const double augmentationProbability) :
      augmentations(augmentations),
      augmentationProbability(augmentationProbability)
  {
    // Convert strings to lower case.
    for (size_t i = 0; i < augmentations.size(); i++)
      mlpack::util::ToLower(augmentations[i], this->augmentations[i]);

    // Place resize entries at the front so that HasResizeParam() only has to
    // look at the first element. A partition is used instead of a sort because
    // "contains resize" is a unary property: using it as a binary comparator
    // does not form a strict weak ordering, which std::sort requires.
    std::stable_partition(this->augmentations.begin(),
        this->augmentations.end(),
        [](const std::string& str)
        {
          return str.find("resize") != std::string::npos;
        });
  }

  /**
   * Applies augmentation to the passed dataset.
   *
   * @tparam DatasetType Datatype on which augmentation will be done.
   *
   * @param dataset Dataset on which augmentation will be applied.
   * @param datapointWidth Width of a single data point i.e.
   *                       Since each column represents a separate data
   *                       point.
   * @param datapointHeight Height of a single data point.
   * @param datapointDepth Depth of a single data point. For one 2-dimensional
   *                       data point, set it to 1. Defaults to 1.
   */
  template<typename DatasetType>
  void Transform(DatasetType& dataset,
                 const size_t datapointWidth,
                 const size_t datapointHeight,
                 const size_t datapointDepth = 1);

  /**
   * Applies resize transform to the entire dataset.
   *
   * @tparam DatasetType Datatype on which augmentation will be done.
   *
   * @param dataset Dataset on which augmentation will be applied.
   * @param datapointWidth Width of a single data point i.e.
   *                       Since each column represents a separate data
   *                       point.
   * @param datapointHeight Height of a single data point.
   * @param datapointDepth Depth of a single data point. For one 2-dimensional
   *                       data point, set it to 1. Defaults to 1.
   * @param augmentation String containing the transform.
   */
  template<typename DatasetType>
  void ResizeTransform(DatasetType& dataset,
                       const size_t datapointWidth,
                       const size_t datapointHeight,
                       const size_t datapointDepth,
                       const std::string& augmentation);

 private:
  /**
   * Function to determine if augmentation has Resize function.
   *
   * When a non-empty string is passed, only that string is inspected.
   * Otherwise the first stored augmentation is inspected (resize entries are
   * moved to the front by the constructor).
   *
   * @param augmentation Optional argument to check if a string has
   *                     resize substring.
   * @return True if the inspected string contains "resize".
   */
  bool HasResizeParam(const std::string& augmentation = "") const
  {
    if (!augmentation.empty())
      return augmentation.find("resize") != std::string::npos;

    // Search in augmentation vector.
    return !augmentations.empty() &&
        augmentations[0].find("resize") != std::string::npos;
  }

  /**
   * Sets size of output width and output height of the new data.
   *
   * Nothing is written if neither the passed string nor the stored
   * augmentations contain a resize request. Otherwise the first one or two
   * integers found in the string are used; with a single integer the output
   * is square.
   *
   * @param outWidth Output width of resized data point.
   * @param outHeight Output height of resized data point.
   * @param augmentation String from which output width and height
   *                     are extracted.
   */
  void GetResizeParam(size_t& outWidth,
                      size_t& outHeight,
                      const std::string& augmentation)
  {
    // Accept a resize request either in the string being parsed or in the
    // stored augmentations, so that a direct call with a resize string works
    // even when the object was default constructed.
    if (!HasResizeParam(augmentation) && !HasResizeParam())
      return;

    outWidth = 0;
    outHeight = 0;

    // Use regex to find one or two numbers. If only one provided
    // set output width equal to output height.
    boost::regex regex{"[0-9]+"};

    // Create an iterator to find matches.
    boost::sregex_token_iterator matches(augmentation.begin(),
        augmentation.end(), regex, 0), end;

    const size_t matchesCount = std::distance(matches, end);

    if (matchesCount == 0)
    {
      mlpack::Log::Fatal << "Invalid size / shape in " <<
          augmentation << std::endl;
      // Never dereference the end iterator, even if the fatal log stream
      // returns control to the caller.
      return;
    }

    // The pattern only matches digit runs, so the value is never negative.
    outWidth = std::stoul(matches->str());
    if (matchesCount == 1)
    {
      outHeight = outWidth;
    }
    else
    {
      ++matches;
      outHeight = std::stoul(matches->str());
    }
  }

  //! Locally held augmentations and transforms that need to be applied.
  std::vector<std::string> augmentations;

  //! Locally held value of augmentation probability.
  double augmentationProbability;

  // The dataloader class should have access to internal functions of
  // the augmentation class.
  template<typename DatasetX, typename DatasetY, class ScalerType>
  friend class DataLoader;
};

#include "augmentation_impl.hpp" // Include implementation.

#endif
73 changes: 73 additions & 0 deletions augmentation/augmentation_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/**
* @file augmentation_impl.hpp
* @author Kartik Dutt
*
* Implementation of Augmentation class for augmenting data.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

// In case it has not been included already.
#include "augmentation.hpp"

#ifndef MODELS_AUGMENTATION_IMPL_HPP
#define MODELS_AUGMENTATION_IMPL_HPP

template<typename DatasetType>
void Augmentation::Transform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth)
{
// Initialize the augmentation map.
std::unordered_map<std::string, void(*)(DatasetType&,
size_t, size_t, size_t, std::string&)> augmentationMap;

for (size_t i = 0; i < augmentations.size(); i++)
{
if (augmentationMap.count(augmentations[i]))
{
augmentationMap[augmentations[i]](dataset, datapointWidth,
datapointHeight, datapointDepth, augmentations[i]);
}
else if (this->HasResizeParam(augmentations[i]))
{
this->ResizeTransform(dataset, datapointWidth, datapointHeight,
datapointDepth, augmentations[i]);
}
else
{
mlpack::Log::Warn << "Unknown augmentation : \'" <<
augmentations[i] << "\' not found!" << std::endl;
}
}
}

template<typename DatasetType>
void Augmentation::ResizeTransform(
DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation)
{
size_t outputWidth = 0, outputHeight = 0;

// Get output width and output height.
GetResizeParam(outputWidth, outputHeight, augmentation);

// We will use mlpack's bilinear interpolation layer to
// resize the input.
mlpack::ann::BilinearInterpolation<DatasetType, DatasetType> resizeLayer(
datapointWidth, datapointHeight, outputWidth, outputHeight,
datapointDepth);

DatasetType output;
resizeLayer.Forward(dataset, output);
dataset = std::move(output);
}

#endif
Loading