Objective
We have two DataFrames containing information about Venture Capitalists and the startups they have funded. The aim is to find the Venture Capitalists whose funded startups have an average funding above a limit, where each Venture Capitalist has its own limit.
Task
Write PySpark code that joins these DataFrames and returns the Venture Capitalists whose funded startups have an average funding amount strictly greater than their own funding_limit.
The avg_funding field should contain the average funding the Venture Capitalist provided to its startups, cast to float. Save the resulting DataFrame as result_df. The output must contain exactly 3 columns, in the order given in the Expected Output Schema, sorted by vc_id in ascending order.
File Paths
- VC dataset: /home/interview/venture_capitalist.csv
- Startups dataset: /home/interview/funded_startups.csv
- Starter script: /home/interview/vc_funding.py
Schema
venture_capitalist.csv
| Column Name | Type |
| --- | --- |
| vc_id | string |
| vc_name | string |
| funding_limit | float |
funded_startups.csv
| Column Name | Type |
| --- | --- |
| startup_id | string |
| startup_name | string |
| vc_id | string |
| funding | float |
Expected Output Schema
| Column Name | Type |
| --- | --- |
| vc_id | string |
| vc_name | string |
| avg_funding | float |
Example
Given this sample input:
venture_capitalist_df
| vc_id | vc_name | funding_limit |
| --- | --- | --- |
| VC1 | VC Firm 1 | 1.5 |
| VC2 | VC Firm 2 | 2.0 |
| VC3 | VC Firm 3 | 1.75 |
| VC4 | VC Firm 4 | 2.5 |
funded_startups_df
| startup_id | startup_name | vc_id | funding |
| --- | --- | --- | --- |
| S1 | Startup 1 | VC1 | 2.0 |
| S2 | Startup 2 | VC1 | 1.0 |
| S3 | Startup 3 | VC2 | 2.5 |
| S4 | Startup 4 | VC2 | 2.0 |
| S5 | Startup 5 | VC3 | 1.8 |
| S6 | Startup 6 | VC3 | 1.7 |
| S7 | Startup 7 | VC4 | 3.0 |
| S8 | Startup 8 | VC4 | 2.0 |
The expected output would be:
| vc_id | vc_name | avg_funding |
| --- | --- | --- |
| VC2 | VC Firm 2 | 2.25 |
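(Explanation: VC2 funded S3 (2.5) and S4 (2.0), an average of 2.25. Since 2.25 > 2.0 (VC2's limit), VC2 is included. VC1's average is 1.5, which is not strictly greater than its limit of 1.5, so it is excluded. Likewise, VC3 averages 1.75 against a limit of 1.75 and VC4 averages 2.5 against a limit of 2.5, so both are excluded.)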
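The arithmetic in the example can be checked without a Spark session. The sketch below is plain Python (no PySpark), hard-codes the sample rows from the tables above, and applies the same join, average, and strict greater-than filter; the helper name `vcs_above_limit` is hypothetical.

```python
from collections import defaultdict

# Sample rows copied from the example: (vc_id, vc_name, funding_limit)
vcs = [
    ("VC1", "VC Firm 1", 1.5),
    ("VC2", "VC Firm 2", 2.0),
    ("VC3", "VC Firm 3", 1.75),
    ("VC4", "VC Firm 4", 2.5),
]
# (startup_id, startup_name, vc_id, funding)
startups = [
    ("S1", "Startup 1", "VC1", 2.0),
    ("S2", "Startup 2", "VC1", 1.0),
    ("S3", "Startup 3", "VC2", 2.5),
    ("S4", "Startup 4", "VC2", 2.0),
    ("S5", "Startup 5", "VC3", 1.8),
    ("S6", "Startup 6", "VC3", 1.7),
    ("S7", "Startup 7", "VC4", 3.0),
    ("S8", "Startup 8", "VC4", 2.0),
]

def vcs_above_limit(vcs, startups):
    """Hypothetical helper: return (vc_id, vc_name, avg_funding) rows whose
    average funding is strictly greater than the VC's own funding_limit."""
    # Collect every funding amount per vc_id
    amounts = defaultdict(list)
    for _, _, vc_id, funding in startups:
        amounts[vc_id].append(funding)
    result = []
    for vc_id, vc_name, limit in vcs:
        if vc_id not in amounts:
            continue  # no funded startups: dropped, as with an inner join
        avg = sum(amounts[vc_id]) / len(amounts[vc_id])
        if avg > limit:  # strictly greater, so exact ties are excluded
            result.append((vc_id, vc_name, avg))
    return sorted(result)  # tuples sort by vc_id first

print(vcs_above_limit(vcs, startups))
# [('VC2', 'VC Firm 2', 2.25)]
```

VC1, VC3 and VC4 all land exactly on their limits, which is why the strict `>` matters in this example.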
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.appName("PrepareshSpark").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
venture_capitalist_df = spark.read.csv("/home/interview/venture_capitalist.csv", header=True, inferSchema=True)
funded_startups_df = spark.read.csv("/home/interview/funded_startups.csv", header=True, inferSchema=True)
# Step 1: Join the DataFrames on the shared vc_id
joined_df = venture_capitalist_df.join(funded_startups_df, on="vc_id", how="inner")
# Step 2: Group by the VC details (including the limit) to calculate the average funding
agg_df = joined_df.groupBy("vc_id", "vc_name", "funding_limit").agg(
    F.avg("funding").cast("float").alias("avg_funding")
)
# Step 3: Filter for VCs where their calculated average exceeds their specific limit
result_df = agg_df.filter(F.col("avg_funding") > F.col("funding_limit"))
# Step 4: Select exactly the 3 requested columns and sort deterministically
result_df = result_df.select("vc_id", "vc_name", "avg_funding").orderBy("vc_id")
# --- Do not edit below this line ---
result_df.coalesce(1).write.csv("/home/interview/output", header=True, mode="overwrite")
spark.stop()
Explanation
Step 1: Joining the Datasets
joined_df = venture_capitalist_df.join(funded_startups_df, on="vc_id", how="inner")
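To compare each startup's funding against its VC's specific limit, we first need to bring that limit into the startup rows. Because vc_id acts as the primary key of the VC table and as a foreign key in the startups table, an inner join attaches each VC's name and funding_limit to every startup that VC funded. The inner join also drops VCs with no funded startups (and startups whose vc_id has no matching VC), which is acceptable here because such a VC has no average to compare.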
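To make the inner-join behaviour concrete, here is a plain-Python sketch (no PySpark) on hypothetical rows invented for illustration: VC9 funded nothing and S9 points at a VC that does not exist, so neither survives the join. The dict models the assumption that vc_id is unique in the VC table; the helper `inner_join` is hypothetical.

```python
# Hypothetical rows, invented for illustration only
vcs = {"VC1": ("VC Firm 1", 1.5), "VC9": ("VC Firm 9", 1.0)}  # VC9 funded nothing
startups = [
    ("S1", "Startup 1", "VC1", 2.0),
    ("S2", "Startup 2", "VC1", 1.0),
    ("S9", "Startup 9", "VCX", 5.0),  # VCX has no row in vcs
]

def inner_join(vcs, startups):
    """Pair each startup with its VC's name and limit, keeping only rows
    that match on both sides (the semantics of how="inner")."""
    joined = []
    for startup_id, startup_name, vc_id, funding in startups:
        if vc_id in vcs:  # unmatched startups are dropped
            vc_name, limit = vcs[vc_id]
            joined.append((vc_id, vc_name, limit, startup_id, startup_name, funding))
    return joined

rows = inner_join(vcs, startups)
for row in rows:
    print(row)
# ('VC1', 'VC Firm 1', 1.5, 'S1', 'Startup 1', 2.0)
# ('VC1', 'VC Firm 1', 1.5, 'S2', 'Startup 2', 1.0)
```

Each joined row now carries the limit next to the funding amount, which is exactly what the later filter needs.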
Step 2: Grouping and Aggregating
agg_df = joined_df.groupBy("vc_id", "vc_name", "funding_limit").agg(
    F.avg("funding").cast("float").alias("avg_funding")
)
To calculate the average funding per VC, we group the joined DataFrame. Note that funding_limit is included in the .groupBy() list alongside vc_id and vc_name. Because each VC has a single name and a single limit, adding these columns to the grouping key does not split any group, but it keeps funding_limit available after the aggregation so we can use it in the next step. F.avg() computes the mean, and .cast("float") converts it to the float type required by the output schema.
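The claim that the wider grouping key does not change the groups can be checked in plain Python. This sketch (no PySpark; the helper `group_average` is hypothetical) averages the VC1 and VC2 rows from the example, once keyed by vc_id alone and once by (vc_id, vc_name, funding_limit). It assumes each vc_id has exactly one name and limit; if the VC table held two rows for the same vc_id with different limits, the wider key would split that VC into two groups.

```python
from collections import defaultdict

# Joined rows (vc_id, vc_name, funding_limit, funding) for VC1 and VC2 from the example
joined = [
    ("VC1", "VC Firm 1", 1.5, 2.0),
    ("VC1", "VC Firm 1", 1.5, 1.0),
    ("VC2", "VC Firm 2", 2.0, 2.5),
    ("VC2", "VC Firm 2", 2.0, 2.0),
]

def group_average(rows, key):
    """Average the last field (funding) of each row, bucketed by key(row)."""
    buckets = defaultdict(list)
    for row in rows:
        buckets[key(row)].append(row[-1])
    return {k: sum(v) / len(v) for k, v in buckets.items()}

narrow = group_average(joined, key=lambda r: r[0])  # vc_id only
wide = group_average(joined, key=lambda r: r[:3])   # vc_id, vc_name, funding_limit
print(narrow)  # {'VC1': 1.5, 'VC2': 2.25}
print(wide)    # {('VC1', 'VC Firm 1', 1.5): 1.5, ('VC2', 'VC Firm 2', 2.0): 2.25}
```

The averages are identical; the wider key simply keeps the limit attached to each result.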
Step 3: Filtering on Dynamic Columns
result_df = agg_df.filter(F.col("avg_funding") > F.col("funding_limit"))
Filters often compare a column to a constant, as in F.col("price") > 100. PySpark's .filter() can equally compare two columns within the same row, so we keep only the rows where the computed avg_funding is strictly greater than that VC's own funding_limit.
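One subtlety is worth checking, offered here as a hedged observation rather than something the sample data triggers. With inferSchema=True, Spark reads funding_limit as a double, while avg_funding has already been cast to float before the filter, so Spark widens the float back to double for the comparison. Single-precision rounding can then push an average that exactly ties its limit just above it. The plain-Python sketch below reproduces that rounding; the values are hypothetical and `to_float32` is a hypothetical helper built on the standard `struct` module.

```python
import struct

def to_float32(x: float) -> float:
    """Round a Python float (an IEEE double) to single precision and back,
    mimicking .cast("float") followed by widening to double for comparison."""
    return struct.unpack("f", struct.pack("f", x))[0]

limit = 0.1        # hypothetical funding_limit, read as a double
avg_funding = 0.1  # hypothetical average that exactly ties the limit

print(avg_funding > limit)              # False: the tie is correctly excluded
print(to_float32(avg_funding) > limit)  # True: float rounding pushes it over
```

One way to sidestep this, if such ties can occur, is to filter on the un-cast double average and apply .cast("float") only in the final select.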
Step 4: Formatting and Sorting the Output
result_df = result_df.select("vc_id", "vc_name", "avg_funding").orderBy("vc_id")
Finally, selecting only the three columns listed in the Expected Output Schema drops funding_limit, and chaining .orderBy("vc_id") sorts the result by vc_id in ascending order so the output is deterministic.
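Because vc_id is a string column, .orderBy("vc_id") sorts lexicographically. For the sample IDs VC1 to VC4 this matches numeric order, but with IDs past VC9 (hypothetical here) VC10 would sort before VC2. The plain-Python sketch below shows both orderings; the helper `numeric_suffix` is hypothetical. The task asks only for ascending vc_id, so the plain string sort is the literal reading.

```python
import re

# Hypothetical IDs extending past VC9, for illustration only
ids = ["VC2", "VC10", "VC1", "VC3"]

# Plain string comparison, as orderBy does on a string column
print(sorted(ids))  # ['VC1', 'VC10', 'VC2', 'VC3']

def numeric_suffix(vc_id: str) -> int:
    """Hypothetical sort key: the trailing digits of an ID like 'VC10'."""
    return int(re.search(r"(\d+)$", vc_id).group(1))

print(sorted(ids, key=numeric_suffix))  # ['VC1', 'VC2', 'VC3', 'VC10']
```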