-
Notifications
You must be signed in to change notification settings - Fork 40
Image Dataloader with Field type #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
KimSangYeon-DGU
merged 4 commits into
mlpack:master
from
kartikdutt18:ImageDataloadersBetter
Jun 11, 2020
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
project(augmentation)

option(DEBUG "DEBUG" OFF)

# Make headers two directories above this one reachable by relative include.
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../")

# Files in this directory that are reported to the parent scope.
set(SOURCES
  augmentation.hpp
)

# The list starts with this directory's own path; every source file is then
# appended with the directory prepended to its name.
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
foreach(file IN LISTS SOURCES)
  list(APPEND DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()

# Append sources (with directory name) to the list of all models sources,
# which lives in the parent scope.
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,187 @@ | ||
| /** | ||
| * @file augmentation.hpp | ||
| * @author Kartik Dutt | ||
| * | ||
| * Definition of Augmentation class for augmenting data. | ||
| * | ||
| * mlpack is free software; you may redistribute it and/or modify it under the | ||
| * terms of the 3-clause BSD license. You should have received a copy of the | ||
| * 3-clause BSD license along with mlpack. If not, see | ||
| * http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
| */ | ||
|
|
||
#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

#include <mlpack/methods/ann/layer/bilinear_interpolation.hpp>
#include <mlpack/core/util/to_lower.hpp>
#include <boost/regex.hpp>
|
|
||
| #ifndef MODELS_AUGMENTATION_HPP | ||
| #define MODELS_AUGMENTATION_HPP | ||
|
|
||
| /** | ||
| * Augmentation class used to perform augmentations by transforming the data. | ||
| * For the list of supported augmentation, take a look at our wiki page. | ||
| * | ||
| * @code | ||
| * Augmentation augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2); | ||
| * augmentation.Transform(dataloader.TrainFeatures); | ||
| * @endcode | ||
| */ | ||
| class Augmentation | ||
| { | ||
| public: | ||
| //! Create the augmentation class object. | ||
| Augmentation() : | ||
| augmentations(std::vector<std::string>()), | ||
| augmentationProbability(0.2) | ||
| { | ||
| // Nothing to do here. | ||
| } | ||
|
|
||
| /** | ||
| * Constructor for augmentation class. | ||
| * | ||
| * @param augmentations List of strings containing one of the supported | ||
| * augmentations. | ||
| * @param augmentationProbability Probability of applying augmentation on | ||
| * the dataset. | ||
| * NOTE : This doesn't apply to augmentations | ||
| * such as resize. | ||
| */ | ||
| Augmentation(const std::vector<std::string>& augmentations, | ||
| const double augmentationProbability) : | ||
| augmentations(augmentations), | ||
| augmentationProbability(augmentationProbability) | ||
| { | ||
| // Convert strings to lower case. | ||
| for (size_t i = 0; i < augmentations.size(); i++) | ||
| mlpack::util::ToLower(augmentations[i], this->augmentations[i]); | ||
|
|
||
| // Sort the vector to place resize parameter to the front of the string. | ||
| // This prevents constant lookups for resize. | ||
| sort(this->augmentations.begin(), this->augmentations.end(), []( | ||
| std::string& str1, std::string& str2) | ||
| { | ||
| return str1.find("resize") != std::string::npos; | ||
| }); | ||
| } | ||
|
|
||
| /** | ||
| * Applies augmentation to the passed dataset. | ||
| * | ||
| * @tparam DatasetType Datatype on which augmentation will be done. | ||
| * | ||
| * @param dataset Dataset on which augmentation will be applied. | ||
| * @param datapointWidth Width of a single data point i.e. | ||
| * Since each column represents a seperate data | ||
| * point. | ||
| * @param datapointHeight Height of a single data point. | ||
| * @param datapointDepth Depth of a single data point. For one 2-dimensional | ||
| * data point, set it to 1. Defaults to 1. | ||
| */ | ||
| template<typename DatasetType> | ||
| void Transform(DatasetType& dataset, | ||
| const size_t datapointWidth, | ||
| const size_t datapointHeight, | ||
| const size_t datapointDepth = 1); | ||
|
|
||
| /** | ||
| * Applies resize transform to the entire dataset. | ||
| * | ||
| * @tparam DatasetType Datatype on which augmentation will be done. | ||
| * | ||
| * @param dataset Dataset on which augmentation will be applied. | ||
| * @param datapointWidth Width of a single data point i.e. | ||
| * Since each column represents a seperate data | ||
| * point. | ||
| * @param datapointHeight Height of a single data point. | ||
| * @param datapointDepth Depth of a single data point. For one 2-dimensional | ||
| * data point, set it to 1. Defaults to 1. | ||
| * @param augmentation String containing the transform. | ||
| */ | ||
| template<typename DatasetType> | ||
| void ResizeTransform(DatasetType& dataset, | ||
| const size_t datapointWidth, | ||
| const size_t datapointHeight, | ||
| const size_t datapointDepth, | ||
| const std::string& augmentation); | ||
|
|
||
| private: | ||
| /** | ||
| * Function to determine if augmentation has Resize function. | ||
| * | ||
| * @param augmentation Optional argument to check if a string has | ||
| * resize substring. | ||
| */ | ||
| bool HasResizeParam(const std::string& augmentation = "") | ||
| { | ||
| if (augmentation.length()) | ||
| return augmentation.find("resize") != std::string::npos; | ||
|
|
||
|
|
||
| // Search in augmentation vector. | ||
| return augmentations.size() <= 0 ? false : | ||
| augmentations[0].find("resize") != std::string::npos; | ||
| } | ||
|
|
||
| /** | ||
| * Sets size of output width and output height of the new data. | ||
| * | ||
| * @param outWidth Output width of resized data point. | ||
| * @param outHeight Output height of resized data point. | ||
| * @param augmentation String from which output width and height | ||
| * are extracted. | ||
| */ | ||
| void GetResizeParam(size_t& outWidth, | ||
| size_t& outHeight, | ||
| const std::string& augmentation) | ||
| { | ||
| if (!HasResizeParam()) | ||
| return; | ||
|
|
||
| outWidth = 0; | ||
| outHeight = 0; | ||
|
|
||
| // Use regex to find one or two numbers. If only one provided | ||
| // set output width equal to output height. | ||
| boost::regex regex{"[0-9]+"}; | ||
|
|
||
| // Create an iterator to find matches. | ||
| boost::sregex_token_iterator matches(augmentation.begin(), | ||
| augmentation.end(), regex, 0), end; | ||
|
|
||
| size_t matchesCount = std::distance(matches, end); | ||
|
|
||
| if (matchesCount == 0) | ||
| { | ||
| mlpack::Log::Fatal << "Invalid size / shape in " << | ||
| augmentation << std::endl; | ||
| } | ||
|
|
||
| if (matchesCount == 1) | ||
| { | ||
| outWidth = std::stoi(*matches); | ||
| outHeight = outWidth; | ||
| } | ||
| else | ||
| { | ||
| outWidth = std::stoi(*matches); | ||
| matches++; | ||
| outHeight = std::stoi(*matches); | ||
| } | ||
| } | ||
|
|
||
| //! Locally held augmentations and transforms that need to be applied. | ||
| std::vector<std::string> augmentations; | ||
|
|
||
| //! Locally held value of augmentation probability. | ||
| double augmentationProbability; | ||
|
|
||
| // The dataloader class should have access to internal functions of | ||
| // the dataloader. | ||
| template<typename DatasetX, typename DatasetY, class ScalerType> | ||
| friend class DataLoader; | ||
| }; | ||
|
|
||
| #include "augmentation_impl.hpp" // Include implementation. | ||
|
|
||
| #endif | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| /** | ||
| * @file augmentation_impl.hpp | ||
| * @author Kartik Dutt | ||
| * | ||
| * Implementation of Augmentation class for augmenting data. | ||
| * | ||
| * mlpack is free software; you may redistribute it and/or modify it under the | ||
| * terms of the 3-clause BSD license. You should have received a copy of the | ||
| * 3-clause BSD license along with mlpack. If not, see | ||
| * http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
| */ | ||
|
|
||
| // In case it has not been included already. | ||
| #include "augmentation.hpp" | ||
|
|
||
| #ifndef MODELS_AUGMENTATION_IMPL_HPP | ||
| #define MODELS_AUGMENTATION_IMPL_HPP | ||
|
|
||
| template<typename DatasetType> | ||
| void Augmentation::Transform(DatasetType& dataset, | ||
| const size_t datapointWidth, | ||
| const size_t datapointHeight, | ||
| const size_t datapointDepth) | ||
| { | ||
| // Initialize the augmentation map. | ||
| std::unordered_map<std::string, void(*)(DatasetType&, | ||
| size_t, size_t, size_t, std::string&)> augmentationMap; | ||
|
|
||
| for (size_t i = 0; i < augmentations.size(); i++) | ||
| { | ||
| if (augmentationMap.count(augmentations[i])) | ||
| { | ||
| augmentationMap[augmentations[i]](dataset, datapointWidth, | ||
| datapointHeight, datapointDepth, augmentations[i]); | ||
| } | ||
| else if (this->HasResizeParam(augmentations[i])) | ||
| { | ||
| this->ResizeTransform(dataset, datapointWidth, datapointHeight, | ||
| datapointDepth, augmentations[i]); | ||
| } | ||
| else | ||
| { | ||
| mlpack::Log::Warn << "Unknown augmentation : \'" << | ||
| augmentations[i] << "\' not found!" << std::endl; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| template<typename DatasetType> | ||
| void Augmentation::ResizeTransform( | ||
| DatasetType& dataset, | ||
| const size_t datapointWidth, | ||
| const size_t datapointHeight, | ||
| const size_t datapointDepth, | ||
| const std::string& augmentation) | ||
| { | ||
| size_t outputWidth = 0, outputHeight = 0; | ||
|
|
||
| // Get output width and output height. | ||
| GetResizeParam(outputWidth, outputHeight, augmentation); | ||
|
|
||
| // We will use mlpack's bilinear interpolation layer to | ||
kartikdutt18 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // resize the input. | ||
| mlpack::ann::BilinearInterpolation<DatasetType, DatasetType> resizeLayer( | ||
| datapointWidth, datapointHeight, outputWidth, outputHeight, | ||
| datapointDepth); | ||
|
|
||
| DatasetType output; | ||
| resizeLayer.Forward(dataset, output); | ||
| dataset = std::move(output); | ||
| } | ||
|
|
||
| #endif | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.