
tinn-dotnet's Introduction


Tinn: Tiny Neural Network

Tinn is a tiny, dependency-free neural network implementation for dotnet. It has three configurable layers: an input layer, a hidden layer, and an output layer.

How to get started?

Create a neural network:

var network = new TinyNeuralNetwork(inputCount: 2, hiddenCount: 4, outputCount: 1); 

Load a data set:

// Training data for the XOR operation.
var input = new float[][]
{
    new []{ 1f, 1f }, // --> 0f
    new []{ 1f, 0f }, // --> 1f
    new []{ 0f, 1f }, // --> 1f
    new []{ 0f, 0f }, // --> 0f
};
var expected = new float[][]
{
    new []{ 0f }, // <-- 1f ^ 1f
    new []{ 1f }, // <-- 1f ^ 0f
    new []{ 1f }, // <-- 0f ^ 1f
    new []{ 0f }, // <-- 0f ^ 0f
};

Train the network until a desired accuracy is achieved:

for (int i = 0; i < input.Length; i++)
{
    network.Train(input[i], expected[i], 1f);
}
// Note: a single pass over the data is rarely enough; repeat this loop until the network improves.
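A minimal sketch of such a training loop, using the `GetTotalError` method from the source below (the `0.05f` error threshold and the epoch cap are illustrative values, not part of the library):

```csharp
var epochs = 0;
float totalError;
do
{
    // One pass over the training records.
    for (int i = 0; i < input.Length; i++)
    {
        network.Train(input[i], expected[i], 1f);
    }

    // GetTotalError aggregates the error of the most recent prediction,
    // so run Predict first for each record.
    totalError = 0f;
    for (int i = 0; i < input.Length; i++)
    {
        network.Predict(input[i]);
        totalError += network.GetTotalError(expected[i]);
    }

    epochs++;
} while (totalError > 0.05f && epochs < 10_000);
```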

Try to predict some values:

var prediction = network.Predict(new [] { 1f, 1f });  
// Returns an array with a single probability close to 0f, since 1 ^ 1 = 0.

For more examples, see the examples directory and the automated tests.
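A trained network can also be persisted and restored using the `Save` and `Load` methods shown in the source below (the file name is illustrative):

```csharp
// Persist the trained network to a text file...
network.Save("network.tinn");

// ...and restore it later without retraining.
var restored = TinyNeuralNetwork.Load("network.tinn");
var roundTrip = restored.Predict(new[] { 1f, 1f }); // same prediction as before saving
```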


The original library was written by glouw in C.

tinn-dotnet's People

Contributors: dlidstrom, lawrence-laz


tinn-dotnet's Issues

Some fixes/updates

  • LossFunctionPartialDerivative should have parameters switched around
  • Save/Load should be culture invariant
  • learning phase does not have to compute total error on every backprop call
  • example program minor updates to show accuracy / fix obsolete code
  • reserve 10% of training data for verification, this shows network learns to about 95-96% (i.e. not 99%)
  • changed some things according to personal preference (sorry!)
  • biases are never updated but I'm not sure how to do it
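On the last point: with this network's layout (one shared bias per layer, sigmoid activations), the gradient of the error with respect to each bias is simply the sum of that layer's deltas. A hedged sketch of what could be added to `PropogateBackward` follows; the gradient accumulator names are hypothetical, and `sum` refers to the back-propagated quantity already computed in that method:

```csharp
// Hypothetical sketch, not part of the library: accumulate bias gradients
// alongside the existing weight updates, then apply them after the loops.
float outputBiasGradient = 0f;
for (var j = 0; j < _outputLayer.Length; j++)
{
    // The same delta (a * b) the hidden-to-output weight update uses.
    outputBiasGradient += LossFunctionPartialDerivative(_outputLayer[j], expectedOutput[j])
                        * ActivationFunctionPartialDerivative(_outputLayer[j]);
}

float hiddenBiasGradient = 0f;
for (var i = 0; i < _hiddenLayer.Length; i++)
{
    // 'sum' is the back-propagated error for hidden neuron i, i.e. the same
    // quantity the existing input-to-hidden weight update multiplies by:
    // hiddenBiasGradient += sum * ActivationFunctionPartialDerivative(_hiddenLayer[i]);
}

_biases[1] -= learningRate * outputBiasGradient;
_biases[0] -= learningRate * hiddenBiasGradient;
```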
// The library source:
using System.Globalization;

namespace Tinn;

/// <summary>
/// A tiny neural network with one hidden layer and configurable parameters.
/// </summary>
public class TinyNeuralNetwork
{
    internal float[] _weights;
    internal float[] _biases;
    internal float[] _hiddenLayer;
    internal float[] _outputLayer;
    internal int _inputCount;
    internal Random _random;

    /// <summary>
    /// Creates an instance of an untrained neural network.
    /// </summary>
    /// <param name="inputCount">Number of inputs or features.</param>
    /// <param name="hiddenCount">Number of hidden neurons in a hidden layer.</param>
    /// <param name="outputCount">Number of outputs or classes.</param>
    /// <param name="seed">A seed for the random generator to produce reproducible results.</param>
    public TinyNeuralNetwork(int inputCount, int hiddenCount, int outputCount, int seed = default)
    {
        _random = new Random(seed);
        _inputCount = inputCount;
        _weights = Enumerable.Range(0, hiddenCount * (inputCount + outputCount)).Select(_ => (float)_random.NextDouble() - 0.5f).ToArray();
        _biases = Enumerable.Range(0, 2).Select(_ => (float)_random.NextDouble() - 0.5f).ToArray(); // One shared bias for the hidden layer and one for the output layer.
        _hiddenLayer = new float[hiddenCount];
        _outputLayer = new float[outputCount];
    }

    private TinyNeuralNetwork(float[] weights, float[] biases, float[] hiddenLayer, float[] outputLayer, int inputCount, int seed)
    {
        _weights = weights;
        _biases = biases;
        _hiddenLayer = hiddenLayer;
        _outputLayer = outputLayer;
        _inputCount = inputCount;
        _random = new Random(seed);
    }

    /// <summary>
    /// Loads a pretrained neural network from a `*.tinn` file.
    /// </summary>
    /// <param name="path">An absolute or a relative path to the `*.tinn` file.</param>
    /// <param name="seed">A seed for random generator to produce predictable results.</param>
    /// <returns>An instance of a pretrained <see cref="TinyNeuralNetwork"/>.</returns>
    public static TinyNeuralNetwork Load(string path, int seed = default)
    {
        using StreamReader reader = new(path);
        string metaData = ReadLine();
        var counts = metaData.Split(' ').Select(int.Parse).ToArray();
        var inputCount = counts[0];
        var hiddenCount = counts[1];
        var outputCount = counts[2];

        var weights = new float[hiddenCount * (inputCount + outputCount)];
        var biases = new float[2];
        var hiddenLayer = new float[hiddenCount];
        var outputLayer = new float[outputCount];
        var biasCount = 2;
        for (var i = 0; i < biasCount; i++)
        {
            biases[i] = float.Parse(ReadLine(), CultureInfo.InvariantCulture);
        }

        for (int i = 0; i < weights.Length; i++)
        {
            weights[i] = float.Parse(ReadLine(), CultureInfo.InvariantCulture);
        }

        TinyNeuralNetwork network = new(weights, biases, hiddenLayer, outputLayer, inputCount, seed);
        return network;

        string ReadLine()
        {
            return reader.ReadLine() ?? throw new Exception("invalid file");
        }
    }

    /// <summary>
    /// Predicts outputs for a given input.
    /// </summary>
    /// <param name="input">A float array whose length matches the input count.</param>
    /// <returns>An array of predicted probabilities, one per output class.</returns>
    public float[] Predict(float[] input)
    {
        PropagateForward(input);
        return _outputLayer;
    }

    /// <summary>
    /// Trains the neural network on a single data record.
    /// </summary>
    /// <param name="input">The record's input or feature values.</param>
    /// <param name="expectedOutput">The record's actual class in a categorical format.</param>
    /// <param name="learningRate">The learning rate used for this training step.</param>
    public void Train(float[] input, float[] expectedOutput, float learningRate)
    {
        PropagateForward(input);
        PropogateBackward(input, expectedOutput, learningRate);
    }

    /// <summary>
    /// Gets the total error of the most recent prediction relative to the expected output.
    /// </summary>
    /// <param name="expectedOutput">Actual record's class in a categorical format.</param>
    /// <returns>Aggregated error value indicating how far off the neural network is on the training data set.</returns>
    public float GetTotalError(float[] expectedOutput)
    {
        return GetTotalError(expectedOutput, _outputLayer);
    }

    /// <summary>
    /// Saves a trained neural network to a `*.tinn` file.
    /// </summary>
    /// <param name="path">An absolute or a relative path to the `*.tinn` file.</param>
    public void Save(string path)
    {
        using StreamWriter writer = new FormattingStreamWriter(path, CultureInfo.InvariantCulture);
        writer.WriteLine($"{_inputCount} {_hiddenLayer.Length} {_outputLayer.Length}");
        foreach (float bias in _biases)
            writer.WriteLine(bias);

        foreach (float weight in _weights)
            writer.WriteLine(weight);
    }

    private void PropagateForward(float[] input)
    {
        // Calculate hidden layer neuron values.
        for (var i = 0; i < _hiddenLayer.Length; i++)
        {
            var sum = 0.0f;
            for (var j = 0; j < _inputCount; j++)
                sum += input[j] * _weights[i * _inputCount + j];

            _hiddenLayer[i] = ActivationFunction(sum + _biases[0]);
        }

        // Calculate output layer neuron values.
        for (int i = 0; i < _outputLayer.Length; i++)
        {
            var sum = 0.0f;

            for (int j = 0; j < _hiddenLayer.Length; j++)
                sum += _hiddenLayer[j] * _weights[(_hiddenLayer.Length * _inputCount) + i * _hiddenLayer.Length + j];

            _outputLayer[i] = ActivationFunction(sum + _biases[1]);
        }
    }

    private void PropogateBackward(float[] input, float[] expectedOutput, float learningRate)
    {
        for (var i = 0; i < _hiddenLayer.Length; i++)
        {
            var sum = 0.0f;

            // Calculate total error change with respect to output.
            for (var j = 0; j < _outputLayer.Length; j++)
            {
                float a = LossFunctionPartialDerivative(_outputLayer[j], expectedOutput[j]);
                float b = ActivationFunctionPartialDerivative(_outputLayer[j]);
                sum += a * b * _weights[(_hiddenLayer.Length * _inputCount) + j * _hiddenLayer.Length + i];

                // Correct weights in hidden to output layer.
                _weights[(_hiddenLayer.Length * _inputCount) + j * _hiddenLayer.Length + i] -= learningRate * a * b * _hiddenLayer[i];
            }

            // Correct weights in input to hidden layer.
            for (int j = 0; j < _inputCount; j++)
            {
                _weights[i * _inputCount + j] -= learningRate * sum * ActivationFunctionPartialDerivative(_hiddenLayer[i]) * input[j];
            }
        }
    }

    private static float ActivationFunction(float value)
    {
        return 1.0f / (1.0f + (float)Math.Exp(-value));
    }

    private static float ActivationFunctionPartialDerivative(float value)
    {
        return value * (1f - value);
    }

    private static float LossFunction(float expected, float actual)
    {
        return 0.5f * (expected - actual) * (expected - actual);
    }

    private static float LossFunctionPartialDerivative(float actual, float expected)
    {
        return actual - expected;
    }

    private static float GetTotalError(float[] expected, float[] actual)
    {
        float totalError = expected.Zip(actual, (e, a) => LossFunction(e, a)).Sum();
        return totalError;
    }

    private class FormattingStreamWriter : StreamWriter
    {
        private readonly IFormatProvider _formatProvider;

        public FormattingStreamWriter(string path, IFormatProvider formatProvider)
            : base(path)
        {
            _formatProvider = formatProvider;
        }

        public override IFormatProvider FormatProvider => _formatProvider;
    }
}
// Example program: trains on the Semeion handwritten digit data set.
using System.Globalization;
using ShellProgressBar;
using Tinn;

const string datasetUri = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data";
const string dataSetFileName = "semeion.data";
const int inputCount = 256;
const int hiddenCount = 28;
const int outputCount = 10;

// reserve ~10% of data for verification
const int verifyCount = 150;

const int learningIterations = 10;
float learningRate = 0.1f;
const float learningRateDecay = 0.95f;
var random = new Random(0);

if (File.Exists(dataSetFileName) == false)
{
    Console.WriteLine("Downloading the Semeion handwritten digit dataset...");
    using HttpClient client = new();
    using Stream file = File.Create(dataSetFileName);
    Stream stream = await client.GetStreamAsync(datasetUri);
    stream.CopyTo(file);
    Console.WriteLine("Download completed.");
}

(float[] Input, float[] Output)[] allData = File.ReadAllLines(dataSetFileName)
    .Select(line => line.Split(" ").Select(x => float.Parse(x, CultureInfo.InvariantCulture)))
    .Select(number => (
        Input: number.Take(inputCount).ToArray(),
        Output: number.Skip(inputCount).Take(outputCount).ToArray())
    )
    .ToArray();

(float[] Input, float[] Output)[] learningData = allData.Skip(verifyCount).ToArray();
(float[] Input, float[] Output)[] verifyData = allData.Take(verifyCount).ToArray();

var network = new TinyNeuralNetwork(inputCount, hiddenCount, outputCount);
var progress = new ProgressBar(learningIterations, "Training...");

string currentAccuracy = "";
for (var i = 0; i < learningIterations; i++)
{
    using ChildProgressBar child = progress.Spawn(learningData.Length, "iteration " + i, new ProgressBarOptions { CollapseWhenFinished = true });
    foreach (var (data, n) in learningData.Select((data, n) => (data, n)))
    {
        network.Train(data.Input, data.Output, learningRate);
        if (n == learningData.Length - 1)
        {
            currentAccuracy = ComputeAccuracy(verifyData, network);
        }

        child.Tick();
    }

    Shuffle(learningData);
    learningRate *= learningRateDecay;
    progress.Tick(currentAccuracy);
}

network.Save("network.tinn");
currentAccuracy = ComputeAccuracy(verifyData, network);
Console.WriteLine(currentAccuracy);

// Shuffles the data set between training iterations (Fisher-Yates).
void Shuffle<T>(T[] array)
{
    for (int i = array.Length - 1; i > 0; i--)
    {
        var j = random.Next(i + 1);
        (array[i], array[j]) = (array[j], array[i]);
    }
}

string ComputeAccuracy((float[] Input, float[] Output)[] subset, TinyNeuralNetwork network)
{
    int[] predictedNumbers = subset
        .Select(x => network.Predict(x.Input))
        .Select(f => f.Select((n, i) => (n, i)).Max().i) // Arg-max: index of the highest probability.
        .ToArray();

    int[] actualNumbers = subset
        .Select(record => record.Output.Select((n, i) => (n, i)).Max().i)
        .ToArray();

    double correctlyGuessed = predictedNumbers.Zip(actualNumbers, (l, r) => l == r ? 1.0 : 0.0).Sum();
    double accuracy = correctlyGuessed / actualNumbers.Length;
    return $"Achieved {accuracy:P2} accuracy.";
}
