Find a linear function that represents a set of data points in Java

By xngo on February 21, 2019

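To find a linear function (a straight line) that best represents a set of data points, you can use simple linear regression, which chooses the line that minimizes the sum of squared vertical distances between the points and the line (the method of least squares). Here is an example.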

/**
 * Simple Linear Regression: Find a linear function that represents a set of data points.
 * @author Xuan Ngo 
 */
import java.util.ArrayList;
import static java.lang.Math.pow;
import java.lang.RuntimeException;
 
/**
 * Simple linear regression: fits the least-squares line y = slope * x + intercept to a
 * set of (x, y) data points.
 *
 * <p>The fit is computed once, in the constructor, from the values present at that time;
 * later changes to the caller's lists do not affect the result.
 */
public class SimpleLinearRegression
{
  /** Least-squares slope b, computed in the constructor. */
  private final double m_dSlope;
  /** Least-squares intercept (the fitted y value at x = 0), computed in the constructor. */
  private final double m_dIntercept;

  /**
   * Fits a line to a small sample data set and prints the result.
   * @param args ignored
   */
  public static void main(String[] args)
  {
    ArrayList<Double> aX = new ArrayList<Double>();
    aX.add(Double.valueOf(60));
    aX.add(Double.valueOf(61));
    aX.add(Double.valueOf(62));
    aX.add(Double.valueOf(63));
    aX.add(Double.valueOf(65));
    ArrayList<Double> aY = new ArrayList<Double>();
    aY.add(Double.valueOf(3.1));
    aY.add(Double.valueOf(3.6));
    aY.add(Double.valueOf(3.8));
    aY.add(Double.valueOf(4));
    aY.add(Double.valueOf(4.1));

    // A poorly correlated data set to try instead:
    //   x = {1, 4, 6, 13, 6},  y = {3, 1, 2, 6, 9}

    SimpleLinearRegression slr = new SimpleLinearRegression(aX, aY);

    System.out.println("Slope = "+slr.getSlope()+"  Intercept = "+slr.getIntercept());
    System.out.println("y = "+slr.getSlope()+"x + ("+slr.getIntercept()+")");
  }

  /**
   * Fits the least-squares line to the points (aX.get(i), aY.get(i)).
   *
   * @param aX x values; must not be null or contain null elements
   * @param aY y values, paired by index with aX; must not be null or contain null elements
   * @throws NullPointerException if either list, or any element, is null
   * @throws IllegalArgumentException if the lists differ in size, hold fewer than two
   *         points, or the x values have no spread (the slope would be undefined)
   */
  public SimpleLinearRegression(final ArrayList<Double> aX, final ArrayList<Double> aY)
  {
    if(aX == null || aY == null)
      throw new NullPointerException("aX and aY must not be null");

    final int iNumOfValues = aX.size();
    if(aY.size() != iNumOfValues)
      throw new IllegalArgumentException(
          "aX and aY must have the same size, got "+iNumOfValues+" and "+aY.size());
    if(iNumOfValues < 2)
      throw new IllegalArgumentException(
          "At least 2 data points are required, got "+iNumOfValues);

    // Copy into primitive arrays: unboxes each value once and snapshots the caller's data.
    final double[] adX = new double[iNumOfValues];
    final double[] adY = new double[iNumOfValues];
    for(int i=0; i<iNumOfValues; i++)
    {
      adX[i] = aX.get(i).doubleValue();
      adY[i] = aY.get(i).doubleValue();
    }

    final double dMeanX = mean(adX);
    final double dMeanY = mean(adY);

    // Centered (two-pass) sums:
    //   Σ(x - x̄)(y - ȳ) = (NΣXY - ΣXΣY) / N
    //   Σ(x - x̄)^2      = (NΣ(X^2) - (ΣX)^2) / N
    // Mathematically identical to the textbook formula, but it avoids catastrophic
    // cancellation when the x values are large relative to their spread.
    double dSumDxDy = 0;
    double dSumDxDx = 0;
    for(int i=0; i<iNumOfValues; i++)
    {
      final double dDx = adX[i] - dMeanX;
      dSumDxDx += dDx * dDx;
      dSumDxDy += dDx * (adY[i] - dMeanY);
    }

    // Check for identical x values directly: a computed mean of identical values can differ
    // from them by rounding, which would make dSumDxDx tiny but non-zero.
    if(!hasDistinctValues(adX) || dSumDxDx == 0)
      throw new IllegalArgumentException(
          "The x values must have a non-zero spread: the slope of a vertical line is undefined");

    this.m_dSlope = dSumDxDy / dSumDxDx;
    // Equivalent to Intercept = (ΣY - b(ΣX)) / N.
    this.m_dIntercept = dMeanY - (this.m_dSlope * dMeanX);
  }

  /**
   * Returns the least-squares slope.
   * Slope = (NΣXY - (ΣX)(ΣY)) / (NΣ(X^2) - (ΣX)^2)
   * where, N = number of values.
   * @return the slope b of the fitted line
   */
  public double getSlope()
  {
    return this.m_dSlope;
  }

  /**
   * Returns the least-squares intercept.
   * Intercept = (ΣY - b(ΣX)) / N
   * where, b = slope and N = number of values.
   * May be called at any time, including when the slope is 0.
   * @return the y value of the fitted line at x = 0
   */
  public double getIntercept()
  {
    return this.m_dIntercept;
  }

  /**
   * Arithmetic mean of a non-empty array.
   * @param adValues values to average; the constructor guarantees length >= 2
   * @return the mean
   */
  private static double mean(final double[] adValues)
  {
    double dSum = 0;
    for(int i=0; i<adValues.length; i++)
    {
      dSum += adValues[i];
    }
    return dSum / adValues.length;
  }

  /**
   * Tells whether the array holds at least two different values.
   * @param adValues values to inspect
   * @return true if some element differs from the first one
   */
  private static boolean hasDistinctValues(final double[] adValues)
  {
    for(int i=1; i<adValues.length; i++)
    {
      if(adValues[i] != adValues[0])
        return true;
    }
    return false;
  }
}
/**
 * Test cases of SimpleLinearRegression class
 * @author Xuan Ngo
 */
import java.util.ArrayList;
import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
 
public class SimpleLinearRegressionTest
{
  @Test
  public void slrPositiveSlopeTest()
  {
    ArrayList<Double> aX = new ArrayList<Double>();
    ArrayList<Double> aY = new ArrayList<Double>();
    aX.add(new Double(0.0));  aY.add(new Double(4.0));
    aX.add(new Double(3.0));  aY.add(new Double(8.0));
 
    SimpleLinearRegression slr = new SimpleLinearRegression(aX, aY);
 
    assertEquals(slr.getSlope(), (8.0-4.0)/(3.0-0.0));
    assertEquals(slr.getIntercept(), 4.0);
  }
 
  @Test
  public void slrNegativeSlopeTest()
  {
    ArrayList<Double> aX = new ArrayList<Double>();
    ArrayList<Double> aY = new ArrayList<Double>();
    aX.add(new Double(0.0));  aY.add(new Double(8.0));
    aX.add(new Double(3.0));  aY.add(new Double(4.0));
 
    SimpleLinearRegression slr = new SimpleLinearRegression(aX, aY);
 
    assertEquals(slr.getSlope(), (4.0-8.0)/(3.0-0.0));
    assertEquals(slr.getIntercept(), 8.0);
  }  
}

Reference

About the author

Xuan Ngo is the founder of OpenWritings.net. He currently lives in Montreal, Canada. He loves to write about programming and open source subjects.